1
0
Fork 0
mirror of https://github.com/Luzifer/mondash.git synced 2024-12-22 20:11:18 +00:00

Updated dependencies

This commit is contained in:
Knut Ahlers 2016-03-27 20:46:08 +02:00
parent 8a3cf7db08
commit d0d5587cd9
231 changed files with 15913 additions and 16797 deletions

104
Godeps/Godeps.json generated
View file

@ -6,52 +6,112 @@
{
"ImportPath": "code.google.com/p/go.text/transform",
"Comment": "null-104",
"Rev": "7ab6d3749429e2c62c53eecadf3e4ec5e715c11b"
"Rev": "'7ab6d3749429e2c62c53eecadf3e4ec5e715c11b'"
},
{
"ImportPath": "code.google.com/p/go.text/unicode/norm",
"Comment": "null-104",
"Rev": "7ab6d3749429e2c62c53eecadf3e4ec5e715c11b"
"Rev": "'7ab6d3749429e2c62c53eecadf3e4ec5e715c11b'"
},
{
"ImportPath": "github.com/Luzifer/rconfig",
"Comment": "v1.0.3-2-g2677653",
"Rev": "26776536e61487fdffbd3ce87f827177a5903f98"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws",
"Comment": "v0.6.4-2-g168a70b",
"Rev": "168a70b9c21a4f60166d7925b690356605907adb"
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/awserr",
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/awsutil",
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/corehandlers",
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/credentials",
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds",
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/defaults",
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/ec2metadata",
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/request",
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/service",
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/service/serviceinfo",
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/endpoints",
"Comment": "v0.6.4-2-g168a70b",
"Rev": "168a70b9c21a4f60166d7925b690356605907adb"
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/query",
"Comment": "v0.6.4-2-g168a70b",
"Rev": "168a70b9c21a4f60166d7925b690356605907adb"
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/query/queryutil",
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/rest",
"Comment": "v0.6.4-2-g168a70b",
"Rev": "168a70b9c21a4f60166d7925b690356605907adb"
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/restxml",
"Comment": "v0.6.4-2-g168a70b",
"Rev": "168a70b9c21a4f60166d7925b690356605907adb"
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil",
"Comment": "v0.6.4-2-g168a70b",
"Rev": "168a70b9c21a4f60166d7925b690356605907adb"
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/signer/v4",
"Comment": "v0.6.4-2-g168a70b",
"Rev": "168a70b9c21a4f60166d7925b690356605907adb"
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/service/s3",
"Comment": "v0.6.4-2-g168a70b",
"Rev": "168a70b9c21a4f60166d7925b690356605907adb"
"Comment": "v0.9.10",
"Rev": "661aeb3339ad9bd5fd420752ebf18a651d40413e"
},
{
"ImportPath": "github.com/extemporalgenome/slug",
@ -85,15 +145,19 @@
},
{
"ImportPath": "github.com/shurcooL/sanitized_anchor_name",
"Rev": "9a8b7d4e8f347bfa230879db9d7d4e4d9e19f962"
"Rev": "244f5ac324cb97e1987ef901a0081a77bfd8e845"
},
{
"ImportPath": "github.com/spf13/pflag",
"Rev": "f1e68ce945b0710375b5cccee37318a3d13fdf8c"
"Rev": "b084184666e02084b8ccb9b704bf0d79c466eb1d"
},
{
"ImportPath": "github.com/vaughan0/go-ini",
"Rev": "a98ad7ee00ec53921f08832bc06ecf7fd600e6a1"
},
{
"ImportPath": "gopkg.in/yaml.v2",
"Rev": "53feefa2559fb8dfa8d81baad31be332c97d6c77"
}
]
}

27
vendor/code.google.com/p/go.text/LICENSE generated vendored Normal file
View file

@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

22
vendor/code.google.com/p/go.text/PATENTS generated vendored Normal file
View file

@ -0,0 +1,22 @@
Additional IP Rights Grant (Patents)
"This implementation" means the copyrightable works distributed by
Google as part of the Go project.
Google hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section)
patent license to make, have made, use, offer to sell, sell, import,
transfer and otherwise run, modify and propagate the contents of this
implementation of Go, where such license applies only to those patent
claims, both currently owned or controlled by Google and acquired in
the future, licensable by Google that are necessarily infringed by this
implementation of Go. This grant does not include claims that would be
infringed only as a consequence of further modification of this
implementation. If you or your agent or exclusive licensee institute or
order or agree to the institution of patent litigation against any
entity (including a cross-claim or counterclaim in a lawsuit) alleging
that this implementation of Go or any code incorporated within this
implementation of Go constitutes direct or contributory patent
infringement, or inducement of patent infringement, then any patent
rights granted to you under this License for this implementation of Go
shall terminate as of the date such litigation is filed.

8
vendor/github.com/Luzifer/rconfig/.travis.yml generated vendored Normal file
View file

@ -0,0 +1,8 @@
language: go
go:
- 1.4
- 1.5
- tip
script: go test -v -race -cover ./...

13
vendor/github.com/Luzifer/rconfig/LICENSE generated vendored Normal file
View file

@ -0,0 +1,13 @@
Copyright 2015 Knut Ahlers <knut@ahlers.me>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

94
vendor/github.com/Luzifer/rconfig/README.md generated vendored Normal file
View file

@ -0,0 +1,94 @@
[![Build Status](https://travis-ci.org/Luzifer/rconfig.svg?branch=master)](https://travis-ci.org/Luzifer/rconfig)
[![License: Apache v2.0](https://badge.luzifer.io/v1/badge?color=5d79b5&title=license&text=Apache+v2.0)](http://www.apache.org/licenses/LICENSE-2.0)
[![Documentation](https://badge.luzifer.io/v1/badge?title=godoc&text=reference)](https://godoc.org/github.com/Luzifer/rconfig)
[![Go Report](http://goreportcard.com/badge/Luzifer/rconfig)](http://goreportcard.com/report/Luzifer/rconfig)
## Description
> Package rconfig implements a CLI configuration reader with struct-embedded defaults, environment variables and posix compatible flag parsing using the [pflag](https://github.com/spf13/pflag) library.
## Installation
Install by running:
```
go get -u github.com/Luzifer/rconfig
```
OR fetch a specific version:
```
go get -u gopkg.in/luzifer/rconfig.v1
```
Run tests by running:
```
go test -v -race -cover github.com/Luzifer/rconfig
```
## Usage
As a first step define a struct holding your configuration:
```go
type config struct {
Username string `default:"unknown" flag:"user" description:"Your name"`
Details struct {
Age int `default:"25" flag:"age" env:"age" description:"Your age"`
}
}
```
Next create an instance of that struct and let `rconfig` fill that config:
```go
var cfg config
func init() {
cfg = config{}
rconfig.Parse(&cfg)
}
```
You're ready to access your configuration:
```go
func main() {
fmt.Printf("Hello %s, happy birthday for your %dth birthday.",
cfg.Username,
cfg.Details.Age)
}
```
### Provide variable defaults by using a file
Given you have a file `~/.myapp.yml` containing some secrets or usernames (for the example below username is assumed to be "luzifer") as a default configuration for your application you can use this source code to load the defaults from that file using the `vardefault` tag in your configuration struct.
The order of the directives (lower number = higher precedence):
1. Flags provided in command line
1. Environment variables
1. Variable defaults (`vardefault` tag in the struct)
1. `default` tag in the struct
```go
type config struct {
Username string `vardefault:"username" flag:"username" description:"Your username"`
}
var cfg = config{}
func init() {
rconfig.SetVariableDefaults(rconfig.VarDefaultsFromYAMLFile("~/.myapp.yml"))
rconfig.Parse(&cfg)
}
func main() {
fmt.Printf("Username = %s", cfg.Username)
// Output: Username = luzifer
}
```
## More info
You can see the full reference documentation of the rconfig package [at godoc.org](https://godoc.org/github.com/Luzifer/rconfig), or through go's standard documentation system by running `godoc -http=:6060` and browsing to [http://localhost:6060/pkg/github.com/Luzifer/rconfig](http://localhost:6060/pkg/github.com/Luzifer/rconfig) after installation.

314
vendor/github.com/Luzifer/rconfig/config.go generated vendored Normal file
View file

@ -0,0 +1,314 @@
// Package rconfig implements a CLI configuration reader with struct-embedded
// defaults, environment variables and posix compatible flag parsing using
// the pflag library.
package rconfig
import (
"errors"
"fmt"
"os"
"reflect"
"strconv"
"strings"
"github.com/spf13/pflag"
)
var (
fs *pflag.FlagSet
variableDefaults map[string]string
)
func init() {
variableDefaults = make(map[string]string)
}
// Parse takes the pointer to a struct filled with variables which should be read
// from ENV, default or flag. The precedence in this is flag > ENV > default. So
// if a flag is specified on the CLI it will overwrite the ENV and otherwise ENV
// overwrites the default specified.
//
// For your configuration struct you can use the following struct-tags to control
// the behavior of rconfig:
//
// default: Set a default value
// vardefault: Read the default value from the variable defaults
// env: Read the value from this environment variable
// flag: Flag to read in format "long,short" (for example "listen,l")
// description: A help text for Usage output to guide your users
//
// The format you need to specify those values you can see in the example to this
// function.
//
func Parse(config interface{}) error {
return parse(config, nil)
}
// Args returns the non-flag command-line arguments.
func Args() []string {
return fs.Args()
}
// Usage prints a basic usage with the corresponding defaults for the flags to
// os.Stdout. The defaults are derived from the `default` struct-tag and the ENV.
func Usage() {
if fs != nil && fs.Parsed() {
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
fs.PrintDefaults()
}
}
// SetVariableDefaults presets the parser with a map of default values to be used
// when specifying the vardefault tag
func SetVariableDefaults(defaults map[string]string) {
variableDefaults = defaults
}
func parse(in interface{}, args []string) error {
if args == nil {
args = os.Args
}
fs = pflag.NewFlagSet(os.Args[0], pflag.ExitOnError)
if err := execTags(in, fs); err != nil {
return err
}
return fs.Parse(args)
}
func execTags(in interface{}, fs *pflag.FlagSet) error {
if reflect.TypeOf(in).Kind() != reflect.Ptr {
return errors.New("Calling parser with non-pointer")
}
if reflect.ValueOf(in).Elem().Kind() != reflect.Struct {
return errors.New("Calling parser with pointer to non-struct")
}
st := reflect.ValueOf(in).Elem()
for i := 0; i < st.NumField(); i++ {
valField := st.Field(i)
typeField := st.Type().Field(i)
if typeField.Tag.Get("default") == "" && typeField.Tag.Get("env") == "" && typeField.Tag.Get("flag") == "" && typeField.Type.Kind() != reflect.Struct {
// None of our supported tags is present and it's not a sub-struct
continue
}
value := varDefault(typeField.Tag.Get("vardefault"), typeField.Tag.Get("default"))
value = envDefault(typeField.Tag.Get("env"), value)
parts := strings.Split(typeField.Tag.Get("flag"), ",")
switch typeField.Type.Kind() {
case reflect.String:
if typeField.Tag.Get("flag") != "" {
if len(parts) == 1 {
fs.StringVar(valField.Addr().Interface().(*string), parts[0], value, typeField.Tag.Get("description"))
} else {
fs.StringVarP(valField.Addr().Interface().(*string), parts[0], parts[1], value, typeField.Tag.Get("description"))
}
} else {
valField.SetString(value)
}
case reflect.Bool:
v := value == "true"
if typeField.Tag.Get("flag") != "" {
if len(parts) == 1 {
fs.BoolVar(valField.Addr().Interface().(*bool), parts[0], v, typeField.Tag.Get("description"))
} else {
fs.BoolVarP(valField.Addr().Interface().(*bool), parts[0], parts[1], v, typeField.Tag.Get("description"))
}
} else {
valField.SetBool(v)
}
case reflect.Int, reflect.Int8, reflect.Int32, reflect.Int64:
vt, err := strconv.ParseInt(value, 10, 64)
if err != nil {
if value == "" {
vt = 0
} else {
return err
}
}
if typeField.Tag.Get("flag") != "" {
registerFlagInt(typeField.Type.Kind(), fs, valField.Addr().Interface(), parts, vt, typeField.Tag.Get("description"))
} else {
valField.SetInt(vt)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
vt, err := strconv.ParseUint(value, 10, 64)
if err != nil {
if value == "" {
vt = 0
} else {
return err
}
}
if typeField.Tag.Get("flag") != "" {
registerFlagUint(typeField.Type.Kind(), fs, valField.Addr().Interface(), parts, vt, typeField.Tag.Get("description"))
} else {
valField.SetUint(vt)
}
case reflect.Float32, reflect.Float64:
vt, err := strconv.ParseFloat(value, 64)
if err != nil {
if value == "" {
vt = 0.0
} else {
return err
}
}
if typeField.Tag.Get("flag") != "" {
registerFlagFloat(typeField.Type.Kind(), fs, valField.Addr().Interface(), parts, vt, typeField.Tag.Get("description"))
} else {
valField.SetFloat(vt)
}
case reflect.Struct:
if err := execTags(valField.Addr().Interface(), fs); err != nil {
return err
}
case reflect.Slice:
switch typeField.Type.Elem().Kind() {
case reflect.Int:
def := []int{}
for _, v := range strings.Split(value, ",") {
it, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64)
if err != nil {
return err
}
def = append(def, int(it))
}
if len(parts) == 1 {
fs.IntSliceVar(valField.Addr().Interface().(*[]int), parts[0], def, typeField.Tag.Get("description"))
} else {
fs.IntSliceVarP(valField.Addr().Interface().(*[]int), parts[0], parts[1], def, typeField.Tag.Get("description"))
}
case reflect.String:
del := typeField.Tag.Get("delimiter")
if len(del) == 0 {
del = ","
}
def := strings.Split(value, del)
if len(parts) == 1 {
fs.StringSliceVar(valField.Addr().Interface().(*[]string), parts[0], def, typeField.Tag.Get("description"))
} else {
fs.StringSliceVarP(valField.Addr().Interface().(*[]string), parts[0], parts[1], def, typeField.Tag.Get("description"))
}
}
}
}
return nil
}
func registerFlagFloat(t reflect.Kind, fs *pflag.FlagSet, field interface{}, parts []string, vt float64, desc string) {
switch t {
case reflect.Float32:
if len(parts) == 1 {
fs.Float32Var(field.(*float32), parts[0], float32(vt), desc)
} else {
fs.Float32VarP(field.(*float32), parts[0], parts[1], float32(vt), desc)
}
case reflect.Float64:
if len(parts) == 1 {
fs.Float64Var(field.(*float64), parts[0], float64(vt), desc)
} else {
fs.Float64VarP(field.(*float64), parts[0], parts[1], float64(vt), desc)
}
}
}
func registerFlagInt(t reflect.Kind, fs *pflag.FlagSet, field interface{}, parts []string, vt int64, desc string) {
switch t {
case reflect.Int:
if len(parts) == 1 {
fs.IntVar(field.(*int), parts[0], int(vt), desc)
} else {
fs.IntVarP(field.(*int), parts[0], parts[1], int(vt), desc)
}
case reflect.Int8:
if len(parts) == 1 {
fs.Int8Var(field.(*int8), parts[0], int8(vt), desc)
} else {
fs.Int8VarP(field.(*int8), parts[0], parts[1], int8(vt), desc)
}
case reflect.Int32:
if len(parts) == 1 {
fs.Int32Var(field.(*int32), parts[0], int32(vt), desc)
} else {
fs.Int32VarP(field.(*int32), parts[0], parts[1], int32(vt), desc)
}
case reflect.Int64:
if len(parts) == 1 {
fs.Int64Var(field.(*int64), parts[0], int64(vt), desc)
} else {
fs.Int64VarP(field.(*int64), parts[0], parts[1], int64(vt), desc)
}
}
}
func registerFlagUint(t reflect.Kind, fs *pflag.FlagSet, field interface{}, parts []string, vt uint64, desc string) {
switch t {
case reflect.Uint:
if len(parts) == 1 {
fs.UintVar(field.(*uint), parts[0], uint(vt), desc)
} else {
fs.UintVarP(field.(*uint), parts[0], parts[1], uint(vt), desc)
}
case reflect.Uint8:
if len(parts) == 1 {
fs.Uint8Var(field.(*uint8), parts[0], uint8(vt), desc)
} else {
fs.Uint8VarP(field.(*uint8), parts[0], parts[1], uint8(vt), desc)
}
case reflect.Uint16:
if len(parts) == 1 {
fs.Uint16Var(field.(*uint16), parts[0], uint16(vt), desc)
} else {
fs.Uint16VarP(field.(*uint16), parts[0], parts[1], uint16(vt), desc)
}
case reflect.Uint32:
if len(parts) == 1 {
fs.Uint32Var(field.(*uint32), parts[0], uint32(vt), desc)
} else {
fs.Uint32VarP(field.(*uint32), parts[0], parts[1], uint32(vt), desc)
}
case reflect.Uint64:
if len(parts) == 1 {
fs.Uint64Var(field.(*uint64), parts[0], uint64(vt), desc)
} else {
fs.Uint64VarP(field.(*uint64), parts[0], parts[1], uint64(vt), desc)
}
}
}
func envDefault(env, def string) string {
value := def
if env != "" {
if e := os.Getenv(env); e != "" {
value = e
}
}
return value
}
func varDefault(name, def string) string {
value := def
if name != "" {
if v, ok := variableDefaults[name]; ok {
value = v
}
}
return value
}

View file

@ -0,0 +1,27 @@
package rconfig
import (
"io/ioutil"
"gopkg.in/yaml.v2"
)
// VarDefaultsFromYAMLFile reads contents of a file and calls VarDefaultsFromYAML
func VarDefaultsFromYAMLFile(filename string) map[string]string {
data, err := ioutil.ReadFile(filename)
if err != nil {
return make(map[string]string)
}
return VarDefaultsFromYAML(data)
}
// VarDefaultsFromYAML creates a vardefaults map from YAML raw data
func VarDefaultsFromYAML(in []byte) map[string]string {
out := make(map[string]string)
err := yaml.Unmarshal(in, &out)
if err != nil {
return make(map[string]string)
}
return out
}

202
vendor/github.com/aws/aws-sdk-go/LICENSE.txt generated vendored Normal file
View file

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

3
vendor/github.com/aws/aws-sdk-go/NOTICE.txt generated vendored Normal file
View file

@ -0,0 +1,3 @@
AWS SDK for Go
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copyright 2014-2015 Stripe, Inc.

View file

@ -10,23 +10,23 @@ package awserr
//
// Example:
//
// output, err := s3manage.Upload(svc, input, opts)
// if err != nil {
// if awsErr, ok := err.(awserr.Error); ok {
// // Get error details
// log.Println("Error:", err.Code(), err.Message())
// output, err := s3manage.Upload(svc, input, opts)
// if err != nil {
// if awsErr, ok := err.(awserr.Error); ok {
// // Get error details
// log.Println("Error:", err.Code(), err.Message())
//
// Prints out full error message, including original error if there was one.
// log.Println("Error:", err.Error())
// // Prints out full error message, including original error if there was one.
// log.Println("Error:", err.Error())
//
// // Get original error
// if origErr := err.Err(); origErr != nil {
// // operate on original error.
// }
// } else {
// fmt.Println(err.Error())
// }
// }
// // Get original error
// if origErr := err.Err(); origErr != nil {
// // operate on original error.
// }
// } else {
// fmt.Println(err.Error())
// }
// }
//
type Error interface {
// Satisfy the generic error interface.

View file

@ -113,7 +113,7 @@ func newRequestError(err Error, statusCode int, requestID string) *requestError
// Error returns the string representation of the error.
// Satisfies the error interface.
func (r requestError) Error() string {
extra := fmt.Sprintf("status code: %d, request id: [%s]",
extra := fmt.Sprintf("status code: %d, request id: %s",
r.statusCode, r.requestID)
return SprintError(r.Code(), r.Message(), extra, r.OrigErr())
}

View file

@ -70,12 +70,20 @@ func rcopy(dst, src reflect.Value, root bool) {
}
}
case reflect.Slice:
if src.IsNil() {
break
}
s := reflect.MakeSlice(src.Type(), src.Len(), src.Cap())
dst.Set(s)
for i := 0; i < src.Len(); i++ {
rcopy(dst.Index(i), src.Index(i), false)
}
case reflect.Map:
if src.IsNil() {
break
}
s := reflect.MakeMap(src.Type())
dst.Set(s)
for _, k := range src.MapKeys() {

View file

@ -1,193 +0,0 @@
package awsutil_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"testing"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
func ExampleCopy() {
type Foo struct {
A int
B []*string
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
f1 := &Foo{A: 1, B: []*string{&str1, &str2}}
// Do the copy
var f2 Foo
awsutil.Copy(&f2, f1)
// Print the result
fmt.Println(awsutil.StringValue(f2))
// Output:
// {
// A: 1,
// B: ["hello","bye bye"]
// }
}
func TestCopy(t *testing.T) {
type Foo struct {
A int
B []*string
C map[string]*int
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
int1 := 1
int2 := 2
f1 := &Foo{
A: 1,
B: []*string{&str1, &str2},
C: map[string]*int{
"A": &int1,
"B": &int2,
},
}
// Do the copy
var f2 Foo
awsutil.Copy(&f2, f1)
// Values are equal
assert.Equal(t, f2.A, f1.A)
assert.Equal(t, f2.B, f1.B)
assert.Equal(t, f2.C, f1.C)
// But pointers are not!
str3 := "nothello"
int3 := 57
f2.A = 100
f2.B[0] = &str3
f2.C["B"] = &int3
assert.NotEqual(t, f2.A, f1.A)
assert.NotEqual(t, f2.B, f1.B)
assert.NotEqual(t, f2.C, f1.C)
}
func TestCopyIgnoreNilMembers(t *testing.T) {
type Foo struct {
A *string
}
f := &Foo{}
assert.Nil(t, f.A)
var f2 Foo
awsutil.Copy(&f2, f)
assert.Nil(t, f2.A)
fcopy := awsutil.CopyOf(f)
f3 := fcopy.(*Foo)
assert.Nil(t, f3.A)
}
func TestCopyPrimitive(t *testing.T) {
str := "hello"
var s string
awsutil.Copy(&s, &str)
assert.Equal(t, "hello", s)
}
func TestCopyNil(t *testing.T) {
var s string
awsutil.Copy(&s, nil)
assert.Equal(t, "", s)
}
func TestCopyReader(t *testing.T) {
var buf io.Reader = bytes.NewReader([]byte("hello world"))
var r io.Reader
awsutil.Copy(&r, buf)
b, err := ioutil.ReadAll(r)
assert.NoError(t, err)
assert.Equal(t, []byte("hello world"), b)
// empty bytes because this is not a deep copy
b, err = ioutil.ReadAll(buf)
assert.NoError(t, err)
assert.Equal(t, []byte(""), b)
}
func TestCopyDifferentStructs(t *testing.T) {
type SrcFoo struct {
A int
B []*string
C map[string]*int
SrcUnique string
SameNameDiffType int
}
type DstFoo struct {
A int
B []*string
C map[string]*int
DstUnique int
SameNameDiffType string
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
int1 := 1
int2 := 2
f1 := &SrcFoo{
A: 1,
B: []*string{&str1, &str2},
C: map[string]*int{
"A": &int1,
"B": &int2,
},
SrcUnique: "unique",
SameNameDiffType: 1,
}
// Do the copy
var f2 DstFoo
awsutil.Copy(&f2, f1)
// Values are equal
assert.Equal(t, f2.A, f1.A)
assert.Equal(t, f2.B, f1.B)
assert.Equal(t, f2.C, f1.C)
assert.Equal(t, "unique", f1.SrcUnique)
assert.Equal(t, 1, f1.SameNameDiffType)
assert.Equal(t, 0, f2.DstUnique)
assert.Equal(t, "", f2.SameNameDiffType)
}
func ExampleCopyOf() {
type Foo struct {
A int
B []*string
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
f1 := &Foo{A: 1, B: []*string{&str1, &str2}}
// Do the copy
v := awsutil.CopyOf(f1)
var f2 *Foo = v.(*Foo)
// Print the result
fmt.Println(awsutil.StringValue(f2))
// Output:
// {
// A: 1,
// B: ["hello","bye bye"]
// }
}

View file

@ -81,6 +81,12 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool)
value = reflect.Indirect(value)
}
if value.Kind() == reflect.Slice || value.Kind() == reflect.Map {
if !create && value.IsNil() {
value = reflect.ValueOf(nil)
}
}
if value.IsValid() {
nextvals = append(nextvals, value)
}
@ -118,6 +124,12 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool)
}
value = reflect.Indirect(value.Index(i))
if value.Kind() == reflect.Slice || value.Kind() == reflect.Map {
if !create && value.IsNil() {
value = reflect.ValueOf(nil)
}
}
if value.IsValid() {
nextvals = append(nextvals, value)
}

View file

@ -1,65 +0,0 @@
package awsutil_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
type Struct struct {
A []Struct
z []Struct
B *Struct
D *Struct
C string
}
var data = Struct{
A: []Struct{{C: "value1"}, {C: "value2"}, {C: "value3"}},
z: []Struct{{C: "value1"}, {C: "value2"}, {C: "value3"}},
B: &Struct{B: &Struct{C: "terminal"}, D: &Struct{C: "terminal2"}},
C: "initial",
}
func TestValueAtPathSuccess(t *testing.T) {
assert.Equal(t, []interface{}{"initial"}, awsutil.ValuesAtPath(data, "C"))
assert.Equal(t, []interface{}{"value1"}, awsutil.ValuesAtPath(data, "A[0].C"))
assert.Equal(t, []interface{}{"value2"}, awsutil.ValuesAtPath(data, "A[1].C"))
assert.Equal(t, []interface{}{"value3"}, awsutil.ValuesAtPath(data, "A[2].C"))
assert.Equal(t, []interface{}{"value3"}, awsutil.ValuesAtAnyPath(data, "a[2].c"))
assert.Equal(t, []interface{}{"value3"}, awsutil.ValuesAtPath(data, "A[-1].C"))
assert.Equal(t, []interface{}{"value1", "value2", "value3"}, awsutil.ValuesAtPath(data, "A[].C"))
assert.Equal(t, []interface{}{"terminal"}, awsutil.ValuesAtPath(data, "B . B . C"))
assert.Equal(t, []interface{}{"terminal", "terminal2"}, awsutil.ValuesAtPath(data, "B.*.C"))
assert.Equal(t, []interface{}{"initial"}, awsutil.ValuesAtPath(data, "A.D.X || C"))
}
func TestValueAtPathFailure(t *testing.T) {
assert.Equal(t, []interface{}(nil), awsutil.ValuesAtPath(data, "C.x"))
assert.Equal(t, []interface{}(nil), awsutil.ValuesAtPath(data, ".x"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "X.Y.Z"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "A[100].C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "A[3].C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "B.B.C.Z"))
assert.Equal(t, []interface{}(nil), awsutil.ValuesAtPath(data, "z[-1].C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(nil, "A.B.C"))
}
func TestSetValueAtPathSuccess(t *testing.T) {
var s Struct
awsutil.SetValueAtPath(&s, "C", "test1")
awsutil.SetValueAtPath(&s, "B.B.C", "test2")
awsutil.SetValueAtPath(&s, "B.D.C", "test3")
assert.Equal(t, "test1", s.C)
assert.Equal(t, "test2", s.B.B.C)
assert.Equal(t, "test3", s.B.D.C)
awsutil.SetValueAtPath(&s, "B.*.C", "test0")
assert.Equal(t, "test0", s.B.B.C)
assert.Equal(t, "test0", s.B.D.C)
var s2 Struct
awsutil.SetValueAtAnyPath(&s2, "b.b.c", "test0")
assert.Equal(t, "test0", s2.B.B.C)
}

View file

@ -8,16 +8,16 @@ import (
"strings"
)
// StringValue returns the string representation of a value.
func StringValue(i interface{}) string {
// Prettify returns the string representation of a value.
func Prettify(i interface{}) string {
var buf bytes.Buffer
stringValue(reflect.ValueOf(i), 0, &buf)
prettify(reflect.ValueOf(i), 0, &buf)
return buf.String()
}
// stringValue will recursively walk value v to build a textual
// prettify will recursively walk value v to build a textual
// representation of the value.
func stringValue(v reflect.Value, indent int, buf *bytes.Buffer) {
func prettify(v reflect.Value, indent int, buf *bytes.Buffer) {
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
@ -52,7 +52,7 @@ func stringValue(v reflect.Value, indent int, buf *bytes.Buffer) {
val := v.FieldByName(n)
buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(n + ": ")
stringValue(val, indent+2, buf)
prettify(val, indent+2, buf)
if i < len(names)-1 {
buf.WriteString(",\n")
@ -68,7 +68,7 @@ func stringValue(v reflect.Value, indent int, buf *bytes.Buffer) {
buf.WriteString("[" + nl)
for i := 0; i < v.Len(); i++ {
buf.WriteString(id2)
stringValue(v.Index(i), indent+2, buf)
prettify(v.Index(i), indent+2, buf)
if i < v.Len()-1 {
buf.WriteString("," + nl)
@ -82,7 +82,7 @@ func stringValue(v reflect.Value, indent int, buf *bytes.Buffer) {
for i, k := range v.MapKeys() {
buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(k.String() + ": ")
stringValue(v.MapIndex(k), indent+2, buf)
prettify(v.MapIndex(k), indent+2, buf)
if i < v.Len()-1 {
buf.WriteString(",\n")

View file

@ -1,173 +1,239 @@
package aws
import (
"io"
"net/http"
"os"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
// DefaultChainCredentials is a Credentials which will find the first available
// credentials Value from the list of Providers.
//
// This should be used in the default case. Once the type of credentials are
// known switching to the specific Credentials will be more efficient.
var DefaultChainCredentials = credentials.NewChainCredentials(
[]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
&credentials.EC2RoleProvider{ExpiryWindow: 5 * time.Minute},
})
// The default number of retries for a service. The value of -1 indicates that
// the service specific retry default will be used.
const DefaultRetries = -1
// DefaultConfig is the default all service configuration will be based off of.
var DefaultConfig = &Config{
Credentials: DefaultChainCredentials,
Endpoint: "",
Region: os.Getenv("AWS_REGION"),
DisableSSL: false,
ManualSend: false,
HTTPClient: http.DefaultClient,
LogHTTPBody: false,
LogLevel: 0,
Logger: os.Stdout,
MaxRetries: DefaultRetries,
DisableParamValidation: false,
DisableComputeChecksums: false,
S3ForcePathStyle: false,
}
// A Config provides service configuration
// A Config provides service configuration for service clients. By default,
// all clients will use the {defaults.DefaultConfig} structure.
type Config struct {
Credentials *credentials.Credentials
Endpoint string
Region string
DisableSSL bool
ManualSend bool
HTTPClient *http.Client
LogHTTPBody bool
LogLevel uint
Logger io.Writer
MaxRetries int
DisableParamValidation bool
DisableComputeChecksums bool
S3ForcePathStyle bool
// The credentials object to use when signing requests. Defaults to
// {defaults.DefaultChainCredentials}.
Credentials *credentials.Credentials
// An optional endpoint URL (hostname only or fully qualified URI)
// that overrides the default generated endpoint for a client. Set this
// to `""` to use the default generated endpoint.
//
// @note You must still provide a `Region` value when specifying an
// endpoint for a client.
Endpoint *string
// The region to send requests to. This parameter is required and must
// be configured globally or on a per-client basis unless otherwise
// noted. A full list of regions is found in the "Regions and Endpoints"
// document.
//
// @see http://docs.aws.amazon.com/general/latest/gr/rande.html
// AWS Regions and Endpoints
Region *string
// Set this to `true` to disable SSL when sending requests. Defaults
// to `false`.
DisableSSL *bool
// The HTTP client to use when sending requests. Defaults to
// `http.DefaultClient`.
HTTPClient *http.Client
// An integer value representing the logging level. The default log level
// is zero (LogOff), which represents no logging. To enable logging set
// to a LogLevel Value.
LogLevel *LogLevelType
// The logger writer interface to write logging messages to. Defaults to
// standard out.
Logger Logger
// The maximum number of times that a request will be retried for failures.
// Defaults to -1, which defers the max retry setting to the service specific
// configuration.
MaxRetries *int
// Disables semantic parameter validation, which validates input for missing
// required fields and/or other semantic request input errors.
DisableParamValidation *bool
// Disables the computation of request and response checksums, e.g.,
// CRC32 checksums in Amazon DynamoDB.
DisableComputeChecksums *bool
// Set this to `true` to force the request to use path-style addressing,
// i.e., `http://s3.amazonaws.com/BUCKET/KEY`. By default, the S3 client will
// use virtual hosted bucket addressing when possible
// (`http://BUCKET.s3.amazonaws.com/KEY`).
//
// @note This configuration option is specific to the Amazon S3 service.
// @see http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
// Amazon S3: Virtual Hosting of Buckets
S3ForcePathStyle *bool
SleepDelay func(time.Duration)
}
// Copy will return a shallow copy of the Config object.
func (c Config) Copy() Config {
dst := Config{}
dst.Credentials = c.Credentials
dst.Endpoint = c.Endpoint
dst.Region = c.Region
dst.DisableSSL = c.DisableSSL
dst.ManualSend = c.ManualSend
dst.HTTPClient = c.HTTPClient
dst.LogHTTPBody = c.LogHTTPBody
dst.LogLevel = c.LogLevel
dst.Logger = c.Logger
dst.MaxRetries = c.MaxRetries
dst.DisableParamValidation = c.DisableParamValidation
dst.DisableComputeChecksums = c.DisableComputeChecksums
dst.S3ForcePathStyle = c.S3ForcePathStyle
return dst
// NewConfig returns a new Config pointer that can be chained with builder methods to
// set multiple configuration values inline without using pointers.
//
// svc := s3.New(aws.NewConfig().WithRegion("us-west-2").WithMaxRetries(10))
//
func NewConfig() *Config {
return &Config{}
}
// Merge merges the newcfg attribute values into this Config. Each attribute
// will be merged into this config if the newcfg attribute's value is non-zero.
// Due to this, newcfg attributes with zero values cannot be merged in. For
// example bool attributes cannot be cleared using Merge, and must be explicitly
// set on the Config structure.
func (c Config) Merge(newcfg *Config) *Config {
if newcfg == nil {
// WithCredentials sets a config Credentials value returning a Config pointer
// for chaining.
func (c *Config) WithCredentials(creds *credentials.Credentials) *Config {
c.Credentials = creds
return c
}
// WithEndpoint sets a config Endpoint value returning a Config pointer for
// chaining.
func (c *Config) WithEndpoint(endpoint string) *Config {
c.Endpoint = &endpoint
return c
}
// WithRegion sets a config Region value returning a Config pointer for
// chaining.
func (c *Config) WithRegion(region string) *Config {
c.Region = &region
return c
}
// WithDisableSSL sets a config DisableSSL value returning a Config pointer
// for chaining.
func (c *Config) WithDisableSSL(disable bool) *Config {
c.DisableSSL = &disable
return c
}
// WithHTTPClient sets a config HTTPClient value returning a Config pointer
// for chaining.
func (c *Config) WithHTTPClient(client *http.Client) *Config {
c.HTTPClient = client
return c
}
// WithMaxRetries sets a config MaxRetries value returning a Config pointer
// for chaining.
func (c *Config) WithMaxRetries(max int) *Config {
c.MaxRetries = &max
return c
}
// WithDisableParamValidation sets a config DisableParamValidation value
// returning a Config pointer for chaining.
func (c *Config) WithDisableParamValidation(disable bool) *Config {
c.DisableParamValidation = &disable
return c
}
// WithDisableComputeChecksums sets a config DisableComputeChecksums value
// returning a Config pointer for chaining.
func (c *Config) WithDisableComputeChecksums(disable bool) *Config {
c.DisableComputeChecksums = &disable
return c
}
// WithLogLevel sets a config LogLevel value returning a Config pointer for
// chaining.
func (c *Config) WithLogLevel(level LogLevelType) *Config {
c.LogLevel = &level
return c
}
// WithLogger sets a config Logger value returning a Config pointer for
// chaining.
func (c *Config) WithLogger(logger Logger) *Config {
c.Logger = logger
return c
}
// WithS3ForcePathStyle sets a config S3ForcePathStyle value returning a Config
// pointer for chaining.
func (c *Config) WithS3ForcePathStyle(force bool) *Config {
c.S3ForcePathStyle = &force
return c
}
// WithSleepDelay overrides the function used to sleep while waiting for the
// next retry. Defaults to time.Sleep.
func (c *Config) WithSleepDelay(fn func(time.Duration)) *Config {
c.SleepDelay = fn
return c
}
// Merge returns a new Config with the other Config's attribute values merged into
// this Config. If the other Config's attribute is nil it will not be merged into
// the new Config to be returned.
func (c Config) Merge(other *Config) *Config {
if other == nil {
return &c
}
cfg := Config{}
dst := c
if newcfg.Credentials != nil {
cfg.Credentials = newcfg.Credentials
} else {
cfg.Credentials = c.Credentials
if other.Credentials != nil {
dst.Credentials = other.Credentials
}
if newcfg.Endpoint != "" {
cfg.Endpoint = newcfg.Endpoint
} else {
cfg.Endpoint = c.Endpoint
if other.Endpoint != nil {
dst.Endpoint = other.Endpoint
}
if newcfg.Region != "" {
cfg.Region = newcfg.Region
} else {
cfg.Region = c.Region
if other.Region != nil {
dst.Region = other.Region
}
if newcfg.DisableSSL {
cfg.DisableSSL = newcfg.DisableSSL
} else {
cfg.DisableSSL = c.DisableSSL
if other.DisableSSL != nil {
dst.DisableSSL = other.DisableSSL
}
if newcfg.ManualSend {
cfg.ManualSend = newcfg.ManualSend
} else {
cfg.ManualSend = c.ManualSend
if other.HTTPClient != nil {
dst.HTTPClient = other.HTTPClient
}
if newcfg.HTTPClient != nil {
cfg.HTTPClient = newcfg.HTTPClient
} else {
cfg.HTTPClient = c.HTTPClient
if other.LogLevel != nil {
dst.LogLevel = other.LogLevel
}
if newcfg.LogHTTPBody {
cfg.LogHTTPBody = newcfg.LogHTTPBody
} else {
cfg.LogHTTPBody = c.LogHTTPBody
if other.Logger != nil {
dst.Logger = other.Logger
}
if newcfg.LogLevel != 0 {
cfg.LogLevel = newcfg.LogLevel
} else {
cfg.LogLevel = c.LogLevel
if other.MaxRetries != nil {
dst.MaxRetries = other.MaxRetries
}
if newcfg.Logger != nil {
cfg.Logger = newcfg.Logger
} else {
cfg.Logger = c.Logger
if other.DisableParamValidation != nil {
dst.DisableParamValidation = other.DisableParamValidation
}
if newcfg.MaxRetries != DefaultRetries {
cfg.MaxRetries = newcfg.MaxRetries
} else {
cfg.MaxRetries = c.MaxRetries
if other.DisableComputeChecksums != nil {
dst.DisableComputeChecksums = other.DisableComputeChecksums
}
if newcfg.DisableParamValidation {
cfg.DisableParamValidation = newcfg.DisableParamValidation
} else {
cfg.DisableParamValidation = c.DisableParamValidation
if other.S3ForcePathStyle != nil {
dst.S3ForcePathStyle = other.S3ForcePathStyle
}
if newcfg.DisableComputeChecksums {
cfg.DisableComputeChecksums = newcfg.DisableComputeChecksums
} else {
cfg.DisableComputeChecksums = c.DisableComputeChecksums
if other.SleepDelay != nil {
dst.SleepDelay = other.SleepDelay
}
if newcfg.S3ForcePathStyle {
cfg.S3ForcePathStyle = newcfg.S3ForcePathStyle
} else {
cfg.S3ForcePathStyle = c.S3ForcePathStyle
}
return &cfg
return &dst
}
// Copy will return a shallow copy of the Config object.
func (c Config) Copy() *Config {
dst := c
return &dst
}

View file

@ -1,92 +0,0 @@
package aws
import (
"net/http"
"os"
"reflect"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
var testCredentials = credentials.NewChainCredentials([]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{
Filename: "TestFilename",
Profile: "TestProfile"},
&credentials.EC2RoleProvider{ExpiryWindow: 5 * time.Minute},
})
var copyTestConfig = Config{
Credentials: testCredentials,
Endpoint: "CopyTestEndpoint",
Region: "COPY_TEST_AWS_REGION",
DisableSSL: true,
ManualSend: true,
HTTPClient: http.DefaultClient,
LogHTTPBody: true,
LogLevel: 2,
Logger: os.Stdout,
MaxRetries: DefaultRetries,
DisableParamValidation: true,
DisableComputeChecksums: true,
S3ForcePathStyle: true,
}
func TestCopy(t *testing.T) {
want := copyTestConfig
got := copyTestConfig.Copy()
if !reflect.DeepEqual(got, want) {
t.Errorf("Copy() = %+v", got)
t.Errorf(" want %+v", want)
}
}
func TestCopyReturnsNewInstance(t *testing.T) {
want := copyTestConfig
got := copyTestConfig.Copy()
if &got == &want {
t.Errorf("Copy() = %p; want different instance as source %p", &got, &want)
}
}
var mergeTestZeroValueConfig = Config{MaxRetries: DefaultRetries}
var mergeTestConfig = Config{
Credentials: testCredentials,
Endpoint: "MergeTestEndpoint",
Region: "MERGE_TEST_AWS_REGION",
DisableSSL: true,
ManualSend: true,
HTTPClient: http.DefaultClient,
LogHTTPBody: true,
LogLevel: 2,
Logger: os.Stdout,
MaxRetries: 10,
DisableParamValidation: true,
DisableComputeChecksums: true,
S3ForcePathStyle: true,
}
var mergeTests = []struct {
cfg *Config
in *Config
want *Config
}{
{&Config{}, nil, &Config{}},
{&Config{}, &mergeTestZeroValueConfig, &Config{}},
{&Config{}, &mergeTestConfig, &mergeTestConfig},
}
func TestMerge(t *testing.T) {
for _, tt := range mergeTests {
got := tt.cfg.Merge(tt.in)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Config %+v", tt.cfg)
t.Errorf(" Merge(%+v)", tt.in)
t.Errorf(" got %+v", got)
t.Errorf(" want %+v", tt.want)
}
}
}

357
vendor/github.com/aws/aws-sdk-go/aws/convert_types.go generated vendored Normal file
View file

@ -0,0 +1,357 @@
package aws
import "time"
// String returns a pointer to of the string value passed in.
func String(v string) *string {
return &v
}
// StringValue returns the value of the string pointer passed in or
// "" if the pointer is nil.
func StringValue(v *string) string {
if v != nil {
return *v
}
return ""
}
// StringSlice converts a slice of string values into a slice of
// string pointers
func StringSlice(src []string) []*string {
dst := make([]*string, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// StringValueSlice converts a slice of string pointers into a slice of
// string values
func StringValueSlice(src []*string) []string {
dst := make([]string, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// StringMap converts a string map of string values into a string
// map of string pointers
func StringMap(src map[string]string) map[string]*string {
dst := make(map[string]*string)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// StringValueMap converts a string map of string pointers into a string
// map of string values
func StringValueMap(src map[string]*string) map[string]string {
dst := make(map[string]string)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Bool returns a pointer to of the bool value passed in.
func Bool(v bool) *bool {
return &v
}
// BoolValue returns the value of the bool pointer passed in or
// false if the pointer is nil.
func BoolValue(v *bool) bool {
if v != nil {
return *v
}
return false
}
// BoolSlice converts a slice of bool values into a slice of
// bool pointers
func BoolSlice(src []bool) []*bool {
dst := make([]*bool, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// BoolValueSlice converts a slice of bool pointers into a slice of
// bool values
func BoolValueSlice(src []*bool) []bool {
dst := make([]bool, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// BoolMap converts a string map of bool values into a string
// map of bool pointers
func BoolMap(src map[string]bool) map[string]*bool {
dst := make(map[string]*bool)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// BoolValueMap converts a string map of bool pointers into a string
// map of bool values
func BoolValueMap(src map[string]*bool) map[string]bool {
dst := make(map[string]bool)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int returns a pointer to of the int value passed in.
func Int(v int) *int {
return &v
}
// IntValue returns the value of the int pointer passed in or
// 0 if the pointer is nil.
func IntValue(v *int) int {
if v != nil {
return *v
}
return 0
}
// IntSlice converts a slice of int values into a slice of
// int pointers
func IntSlice(src []int) []*int {
dst := make([]*int, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// IntValueSlice converts a slice of int pointers into a slice of
// int values
func IntValueSlice(src []*int) []int {
dst := make([]int, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// IntMap converts a string map of int values into a string
// map of int pointers
func IntMap(src map[string]int) map[string]*int {
dst := make(map[string]*int)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// IntValueMap converts a string map of int pointers into a string
// map of int values
func IntValueMap(src map[string]*int) map[string]int {
dst := make(map[string]int)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int64 returns a pointer to of the int64 value passed in.
func Int64(v int64) *int64 {
return &v
}
// Int64Value returns the value of the int64 pointer passed in or
// 0 if the pointer is nil.
func Int64Value(v *int64) int64 {
if v != nil {
return *v
}
return 0
}
// Int64Slice converts a slice of int64 values into a slice of
// int64 pointers
func Int64Slice(src []int64) []*int64 {
dst := make([]*int64, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Int64ValueSlice converts a slice of int64 pointers into a slice of
// int64 values
func Int64ValueSlice(src []*int64) []int64 {
dst := make([]int64, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Int64Map converts a string map of int64 values into a string
// map of int64 pointers
func Int64Map(src map[string]int64) map[string]*int64 {
dst := make(map[string]*int64)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Int64ValueMap converts a string map of int64 pointers into a string
// map of int64 values
func Int64ValueMap(src map[string]*int64) map[string]int64 {
dst := make(map[string]int64)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Float64 returns a pointer to of the float64 value passed in.
func Float64(v float64) *float64 {
return &v
}
// Float64Value returns the value of the float64 pointer passed in or
// 0 if the pointer is nil.
func Float64Value(v *float64) float64 {
if v != nil {
return *v
}
return 0
}
// Float64Slice converts a slice of float64 values into a slice of
// float64 pointers
func Float64Slice(src []float64) []*float64 {
dst := make([]*float64, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Float64ValueSlice converts a slice of float64 pointers into a slice of
// float64 values
func Float64ValueSlice(src []*float64) []float64 {
dst := make([]float64, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Float64Map converts a string map of float64 values into a string
// map of float64 pointers
func Float64Map(src map[string]float64) map[string]*float64 {
dst := make(map[string]*float64)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Float64ValueMap converts a string map of float64 pointers into a string
// map of float64 values
func Float64ValueMap(src map[string]*float64) map[string]float64 {
dst := make(map[string]float64)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Time returns a pointer to of the time.Time value passed in.
func Time(v time.Time) *time.Time {
return &v
}
// TimeValue returns the value of the time.Time pointer passed in or
// time.Time{} if the pointer is nil.
func TimeValue(v *time.Time) time.Time {
if v != nil {
return *v
}
return time.Time{}
}
// TimeSlice converts a slice of time.Time values into a slice of
// time.Time pointers
func TimeSlice(src []time.Time) []*time.Time {
dst := make([]*time.Time, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// TimeValueSlice converts a slice of time.Time pointers into a slice of
// time.Time values
func TimeValueSlice(src []*time.Time) []time.Time {
dst := make([]time.Time, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// TimeMap converts a string map of time.Time values into a string
// map of time.Time pointers
func TimeMap(src map[string]time.Time) map[string]*time.Time {
dst := make(map[string]*time.Time)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// TimeValueMap converts a string map of time.Time pointers into a string
// map of time.Time values
func TimeValueMap(src map[string]*time.Time) map[string]time.Time {
dst := make(map[string]time.Time)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}

View file

@ -1,4 +1,4 @@
package aws
package corehandlers
import (
"bytes"
@ -9,15 +9,12 @@ import (
"net/url"
"regexp"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
var sleepDelay = func(delay time.Duration) {
time.Sleep(delay)
}
// Interface for matching types which also have a Len method.
type lener interface {
Len() int
@ -26,7 +23,7 @@ type lener interface {
// BuildContentLength builds the content length of a request based on the body,
// or will use the HTTPRequest.Header's "Content-Length" if defined. If unable
// to determine request body length and no "Content-Length" was specified it will panic.
func BuildContentLength(r *Request) {
var BuildContentLengthHandler = request.NamedHandler{"core.BuildContentLengthHandler", func(r *request.Request) {
if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" {
length, _ := strconv.ParseInt(slength, 10, 64)
r.HTTPRequest.ContentLength = length
@ -40,27 +37,27 @@ func BuildContentLength(r *Request) {
case lener:
length = int64(body.Len())
case io.Seeker:
r.bodyStart, _ = body.Seek(0, 1)
r.BodyStart, _ = body.Seek(0, 1)
end, _ := body.Seek(0, 2)
body.Seek(r.bodyStart, 0) // make sure to seek back to original location
length = end - r.bodyStart
body.Seek(r.BodyStart, 0) // make sure to seek back to original location
length = end - r.BodyStart
default:
panic("Cannot get length of body, must provide `ContentLength`")
}
r.HTTPRequest.ContentLength = length
r.HTTPRequest.Header.Set("Content-Length", fmt.Sprintf("%d", length))
}
}}
// UserAgentHandler is a request handler for injecting User agent into requests.
func UserAgentHandler(r *Request) {
r.HTTPRequest.Header.Set("User-Agent", SDKName+"/"+SDKVersion)
}
var UserAgentHandler = request.NamedHandler{"core.UserAgentHandler", func(r *request.Request) {
r.HTTPRequest.Header.Set("User-Agent", aws.SDKName+"/"+aws.SDKVersion)
}}
var reStatusCode = regexp.MustCompile(`^(\d+)`)
var reStatusCode = regexp.MustCompile(`^(\d{3})`)
// SendHandler is a request handler to send service request using HTTP client.
func SendHandler(r *Request) {
var SendHandler = request.NamedHandler{"core.SendHandler", func(r *request.Request) {
var err error
r.HTTPResponse, err = r.Service.Config.HTTPClient.Do(r.HTTPRequest)
if err != nil {
@ -68,8 +65,8 @@ func SendHandler(r *Request) {
// response. e.g. 301 without location header comes back as string
// error and r.HTTPResponse is nil. Other url redirect errors will
// comeback in a similar method.
if e, ok := err.(*url.Error); ok {
if s := reStatusCode.FindStringSubmatch(e.Error()); s != nil {
if e, ok := err.(*url.Error); ok && e.Err != nil {
if s := reStatusCode.FindStringSubmatch(e.Err.Error()); s != nil {
code, _ := strconv.ParseInt(s[1], 10, 64)
r.HTTPResponse = &http.Response{
StatusCode: int(code),
@ -79,7 +76,7 @@ func SendHandler(r *Request) {
return
}
}
if r.HTTPRequest == nil {
if r.HTTPResponse == nil {
// Add a dummy request response object to ensure the HTTPResponse
// value is consistent.
r.HTTPResponse = &http.Response{
@ -90,64 +87,50 @@ func SendHandler(r *Request) {
}
// Catch all other request errors.
r.Error = awserr.New("RequestError", "send request failed", err)
r.Retryable.Set(true) // network errors are retryable
r.Retryable = aws.Bool(true) // network errors are retryable
}
}
}}
// ValidateResponseHandler is a request handler to validate service response.
func ValidateResponseHandler(r *Request) {
var ValidateResponseHandler = request.NamedHandler{"core.ValidateResponseHandler", func(r *request.Request) {
if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 {
// this may be replaced by an UnmarshalError handler
r.Error = awserr.New("UnknownError", "unknown error", nil)
}
}
}}
// AfterRetryHandler performs final checks to determine if the request should
// be retried and how long to delay.
func AfterRetryHandler(r *Request) {
var AfterRetryHandler = request.NamedHandler{"core.AfterRetryHandler", func(r *request.Request) {
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if !r.Retryable.IsSet() {
r.Retryable.Set(r.Service.ShouldRetry(r))
if r.Retryable == nil {
r.Retryable = aws.Bool(r.ShouldRetry(r))
}
if r.WillRetry() {
r.RetryDelay = r.Service.RetryRules(r)
sleepDelay(r.RetryDelay)
r.RetryDelay = r.RetryRules(r)
r.Service.Config.SleepDelay(r.RetryDelay)
// when the expired token exception occurs the credentials
// need to be expired locally so that the next request to
// get credentials will trigger a credentials refresh.
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
if isCodeExpiredCreds(err.Code()) {
r.Config.Credentials.Expire()
}
}
if r.IsErrorExpired() {
r.Service.Config.Credentials.Expire()
}
r.RetryCount++
r.Error = nil
}
}
var (
// ErrMissingRegion is an error that is returned if region configuration is
// not found.
ErrMissingRegion error = awserr.New("MissingRegion", "could not find region configuration", nil)
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be
// resolved for a service.
ErrMissingEndpoint error = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
)
}}
// ValidateEndpointHandler is a request handler to validate a request had the
// appropriate Region and Endpoint set. Will set r.Error if the endpoint or
// region is not valid.
func ValidateEndpointHandler(r *Request) {
if r.Service.SigningRegion == "" && r.Service.Config.Region == "" {
r.Error = ErrMissingRegion
var ValidateEndpointHandler = request.NamedHandler{"core.ValidateEndpointHandler", func(r *request.Request) {
if r.Service.SigningRegion == "" && aws.StringValue(r.Service.Config.Region) == "" {
r.Error = aws.ErrMissingRegion
} else if r.Service.Endpoint == "" {
r.Error = ErrMissingEndpoint
r.Error = aws.ErrMissingEndpoint
}
}
}}

View file

@ -0,0 +1,144 @@
package corehandlers
import (
"fmt"
"reflect"
"strconv"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// ValidateParameters is a request handler to validate the input parameters.
// Validating parameters only has meaning if done prior to the request being sent.
var ValidateParametersHandler = request.NamedHandler{"core.ValidateParametersHandler", func(r *request.Request) {
if r.ParamsFilled() {
v := validator{errors: []string{}}
v.validateAny(reflect.ValueOf(r.Params), "")
if count := len(v.errors); count > 0 {
format := "%d validation errors:\n- %s"
msg := fmt.Sprintf(format, count, strings.Join(v.errors, "\n- "))
r.Error = awserr.New("InvalidParameter", msg, nil)
}
}
}}
// A validator validates values. Collects validations errors which occurs.
type validator struct {
errors []string
}
// validateAny will validate any struct, slice or map type. All validations
// are also performed recursively for nested types.
func (v *validator) validateAny(value reflect.Value, path string) {
value = reflect.Indirect(value)
if !value.IsValid() {
return
}
switch value.Kind() {
case reflect.Struct:
v.validateStruct(value, path)
case reflect.Slice:
for i := 0; i < value.Len(); i++ {
v.validateAny(value.Index(i), path+fmt.Sprintf("[%d]", i))
}
case reflect.Map:
for _, n := range value.MapKeys() {
v.validateAny(value.MapIndex(n), path+fmt.Sprintf("[%q]", n.String()))
}
}
}
// validateStruct will validate the struct value's fields. If the structure has
// nested types those types will be validated also.
func (v *validator) validateStruct(value reflect.Value, path string) {
prefix := "."
if path == "" {
prefix = ""
}
for i := 0; i < value.Type().NumField(); i++ {
f := value.Type().Field(i)
if strings.ToLower(f.Name[0:1]) == f.Name[0:1] {
continue
}
fvalue := value.FieldByName(f.Name)
err := validateField(f, fvalue, validateFieldRequired, validateFieldMin)
if err != nil {
v.errors = append(v.errors, fmt.Sprintf("%s: %s", err.Error(), path+prefix+f.Name))
continue
}
v.validateAny(fvalue, path+prefix+f.Name)
}
}
type validatorFunc func(f reflect.StructField, fvalue reflect.Value) error
func validateField(f reflect.StructField, fvalue reflect.Value, funcs ...validatorFunc) error {
for _, fn := range funcs {
if err := fn(f, fvalue); err != nil {
return err
}
}
return nil
}
// Validates that a field has a valid value provided for required fields.
func validateFieldRequired(f reflect.StructField, fvalue reflect.Value) error {
if f.Tag.Get("required") == "" {
return nil
}
switch fvalue.Kind() {
case reflect.Ptr, reflect.Slice, reflect.Map:
if fvalue.IsNil() {
return fmt.Errorf("missing required parameter")
}
default:
if !fvalue.IsValid() {
return fmt.Errorf("missing required parameter")
}
}
return nil
}
// Validates that if a value is provided for a field, that value must be at
// least a minimum length.
func validateFieldMin(f reflect.StructField, fvalue reflect.Value) error {
minStr := f.Tag.Get("min")
if minStr == "" {
return nil
}
min, _ := strconv.ParseInt(minStr, 10, 64)
kind := fvalue.Kind()
if kind == reflect.Ptr {
if fvalue.IsNil() {
return nil
}
fvalue = fvalue.Elem()
}
switch fvalue.Kind() {
case reflect.String:
if int64(fvalue.Len()) < min {
return fmt.Errorf("field too short, minimum length %d", min)
}
case reflect.Slice, reflect.Map:
if fvalue.IsNil() {
return nil
}
if int64(fvalue.Len()) < min {
return fmt.Errorf("field too short, minimum length %d", min)
}
// TODO min can also apply to number minimum value.
}
return nil
}

View file

@ -7,6 +7,8 @@ import (
var (
// ErrNoValidProvidersFoundInChain Is returned when there are no valid
// providers in the ChainProvider.
//
// @readonly
ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders", "no valid providers in chain", nil)
)
@ -36,7 +38,9 @@ var (
// &EnvProvider{},
// &EC2RoleProvider{},
// })
// creds.Retrieve()
//
// // Usage of ChainCredentials with aws.Config
// svc := ec2.New(&aws.Config{Credentials: creds})
//
type ChainProvider struct {
Providers []Provider

View file

@ -1,73 +0,0 @@
package credentials
import (
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
func TestChainProviderGet(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
},
},
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
}
func TestChainProviderIsExpired(t *testing.T) {
stubProvider := &stubProvider{expired: true}
p := &ChainProvider{
Providers: []Provider{
stubProvider,
},
}
assert.True(t, p.IsExpired(), "Expect expired to be true before any Retrieve")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect not expired after retrieve")
stubProvider.expired = true
assert.True(t, p.IsExpired(), "Expect return of expired provider")
_, err = p.Retrieve()
assert.False(t, p.IsExpired(), "Expect not expired after retrieve")
}
func TestChainProviderWithNoProvider(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
_, err := p.Retrieve()
assert.Equal(t, ErrNoValidProvidersFoundInChain, err, "Expect no providers error returned")
}
func TestChainProviderWithNoValidProvider(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
_, err := p.Retrieve()
assert.Equal(t, ErrNoValidProvidersFoundInChain, err, "Expect no providers error returned")
}

View file

@ -63,6 +63,7 @@ import (
// svc := s3.New(&aws.Config{Credentials: AnonymousCredentials})
// // Access public S3 buckets.
//
// @readonly
var AnonymousCredentials = NewStaticCredentials("", "", "")
// A Value is the AWS credentials value for individual credential fields.

View file

@ -1,62 +0,0 @@
package credentials
import (
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
type stubProvider struct {
creds Value
expired bool
err error
}
func (s *stubProvider) Retrieve() (Value, error) {
s.expired = false
return s.creds, s.err
}
func (s *stubProvider) IsExpired() bool {
return s.expired
}
func TestCredentialsGet(t *testing.T) {
c := NewCredentials(&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
expired: true,
})
creds, err := c.Get()
assert.Nil(t, err, "Expected no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
}
func TestCredentialsGetWithError(t *testing.T) {
c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true})
_, err := c.Get()
assert.Equal(t, "provider error", err.(awserr.Error).Code(), "Expected provider error")
}
func TestCredentialsExpire(t *testing.T) {
stub := &stubProvider{}
c := NewCredentials(stub)
stub.expired = false
assert.True(t, c.IsExpired(), "Expected to start out expired")
c.Expire()
assert.True(t, c.IsExpired(), "Expected to be expired")
c.forceRefresh = false
assert.False(t, c.IsExpired(), "Expected not to be expired")
stub.expired = true
assert.True(t, c.IsExpired(), "Expected to be expired")
}

View file

@ -1,108 +0,0 @@
package credentials
import (
"fmt"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func initTestServer(expireOn string) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI == "/" {
fmt.Fprintln(w, "/creds")
} else {
fmt.Fprintf(w, `{
"AccessKeyId" : "accessKey",
"SecretAccessKey" : "secret",
"Token" : "token",
"Expiration" : "%s"
}`, expireOn)
}
}))
return server
}
func TestEC2RoleProvider(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestEC2RoleProviderIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(3014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL, ExpiryWindow: time.Hour * 1}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 0, 51, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func BenchmarkEC2RoleProvider(b *testing.B) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
}
})
}

View file

@ -1,24 +1,25 @@
package credentials
package ec2rolecreds
import (
"bufio"
"encoding/json"
"fmt"
"net/http"
"path"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
)
const metadataCredentialsEndpoint = "http://169.254.169.254/latest/meta-data/iam/security-credentials/"
// A EC2RoleProvider retrieves credentials from the EC2 service, and keeps track if
// those credentials are expired.
//
// Example how to configure the EC2RoleProvider with custom http Client, Endpoint
// or ExpiryWindow
//
// p := &credentials.EC2RoleProvider{
// p := &ec2rolecreds.EC2RoleProvider{
// // Pass in a custom timeout to be used when requesting
// // IAM EC2 Role credentials.
// Client: &http.Client{
@ -31,15 +32,11 @@ const metadataCredentialsEndpoint = "http://169.254.169.254/latest/meta-data/iam
// // specified the credentials will be expired early
// ExpiryWindow: 0,
// }
//
type EC2RoleProvider struct {
Expiry
credentials.Expiry
// Endpoint must be fully quantified URL
Endpoint string
// HTTP client to use when connecting to EC2 service
Client *http.Client
// EC2Metadata client to use when connecting to EC2 metadata service
Client *ec2metadata.Client
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
@ -53,7 +50,7 @@ type EC2RoleProvider struct {
ExpiryWindow time.Duration
}
// NewEC2RoleCredentials returns a pointer to a new Credentials object
// NewCredentials returns a pointer to a new Credentials object
// wrapping the EC2RoleProvider.
//
// Takes a custom http.Client which can be configured for custom handling of
@ -65,9 +62,8 @@ type EC2RoleProvider struct {
// Window is the expiry window that will be subtracted from the expiry returned
// by the role credential request. This is done so that the credentials will
// expire sooner than their actual lifespan.
func NewEC2RoleCredentials(client *http.Client, endpoint string, window time.Duration) *Credentials {
return NewCredentials(&EC2RoleProvider{
Endpoint: endpoint,
func NewCredentials(client *ec2metadata.Client, window time.Duration) *credentials.Credentials {
return credentials.NewCredentials(&EC2RoleProvider{
Client: client,
ExpiryWindow: window,
})
@ -76,32 +72,29 @@ func NewEC2RoleCredentials(client *http.Client, endpoint string, window time.Dur
// Retrieve retrieves credentials from the EC2 service.
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (m *EC2RoleProvider) Retrieve() (Value, error) {
func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) {
if m.Client == nil {
m.Client = http.DefaultClient
}
if m.Endpoint == "" {
m.Endpoint = metadataCredentialsEndpoint
m.Client = ec2metadata.New(nil)
}
credsList, err := requestCredList(m.Client, m.Endpoint)
credsList, err := requestCredList(m.Client)
if err != nil {
return Value{}, err
return credentials.Value{}, err
}
if len(credsList) == 0 {
return Value{}, awserr.New("EmptyEC2RoleList", "empty EC2 Role list", nil)
return credentials.Value{}, awserr.New("EmptyEC2RoleList", "empty EC2 Role list", nil)
}
credsName := credsList[0]
roleCreds, err := requestCred(m.Client, m.Endpoint, credsName)
roleCreds, err := requestCred(m.Client, credsName)
if err != nil {
return Value{}, err
return credentials.Value{}, err
}
m.SetExpiration(roleCreds.Expiration, m.ExpiryWindow)
return Value{
return credentials.Value{
AccessKeyID: roleCreds.AccessKeyID,
SecretAccessKey: roleCreds.SecretAccessKey,
SessionToken: roleCreds.Token,
@ -111,29 +104,35 @@ func (m *EC2RoleProvider) Retrieve() (Value, error) {
// A ec2RoleCredRespBody provides the shape for deserializing credential
// request responses.
type ec2RoleCredRespBody struct {
// Success State
Expiration time.Time
AccessKeyID string
SecretAccessKey string
Token string
// Error state
Code string
Message string
}
const iamSecurityCredsPath = "/iam/security-credentials"
// requestCredList requests a list of credentials from the EC2 service.
// If there are no credentials, or there is an error making or receiving the request
func requestCredList(client *http.Client, endpoint string) ([]string, error) {
resp, err := client.Get(endpoint)
func requestCredList(client *ec2metadata.Client) ([]string, error) {
resp, err := client.GetMetadata(iamSecurityCredsPath)
if err != nil {
return nil, awserr.New("ListEC2Role", "failed to list EC2 Roles", err)
return nil, awserr.New("EC2RoleRequestError", "failed to list EC2 Roles", err)
}
defer resp.Body.Close()
credsList := []string{}
s := bufio.NewScanner(resp.Body)
s := bufio.NewScanner(strings.NewReader(resp))
for s.Scan() {
credsList = append(credsList, s.Text())
}
if err := s.Err(); err != nil {
return nil, awserr.New("ReadEC2Role", "failed to read list of EC2 Roles", err)
return nil, awserr.New("SerializationError", "failed to read list of EC2 Roles", err)
}
return credsList, nil
@ -143,20 +142,26 @@ func requestCredList(client *http.Client, endpoint string) ([]string, error) {
//
// If the credentials cannot be found, or there is an error reading the response
// and error will be returned.
func requestCred(client *http.Client, endpoint, credsName string) (*ec2RoleCredRespBody, error) {
resp, err := client.Get(endpoint + credsName)
func requestCred(client *ec2metadata.Client, credsName string) (ec2RoleCredRespBody, error) {
resp, err := client.GetMetadata(path.Join(iamSecurityCredsPath, credsName))
if err != nil {
return nil, awserr.New("GetEC2RoleCredentials",
fmt.Sprintf("failed to get %s EC2 Role credentials", credsName),
err)
return ec2RoleCredRespBody{},
awserr.New("EC2RoleRequestError",
fmt.Sprintf("failed to get %s EC2 Role credentials", credsName),
err)
}
defer resp.Body.Close()
respCreds := &ec2RoleCredRespBody{}
if err := json.NewDecoder(resp.Body).Decode(respCreds); err != nil {
return nil, awserr.New("DecodeEC2RoleCredentials",
fmt.Sprintf("failed to decode %s EC2 Role credentials", credsName),
err)
respCreds := ec2RoleCredRespBody{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil {
return ec2RoleCredRespBody{},
awserr.New("SerializationError",
fmt.Sprintf("failed to decode %s EC2 Role credentials", credsName),
err)
}
if respCreds.Code != "Success" {
// If an error code was returned something failed requesting the role.
return ec2RoleCredRespBody{}, awserr.New(respCreds.Code, respCreds.Message, nil)
}
return respCreds, nil

View file

@ -9,9 +9,14 @@ import (
var (
// ErrAccessKeyIDNotFound is returned when the AWS Access Key ID can't be
// found in the process's environment.
//
// @readonly
ErrAccessKeyIDNotFound = awserr.New("EnvAccessKeyNotFound", "AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment", nil)
// ErrSecretAccessKeyNotFound is returned when the AWS Secret Access Key
// can't be found in the process's environment.
//
// @readonly
ErrSecretAccessKeyNotFound = awserr.New("EnvSecretNotFound", "AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment", nil)
)
@ -19,8 +24,9 @@ var (
// running process. Environment credentials never expire.
//
// Environment variables used:
// - Access Key ID: AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY
// - Secret Access Key: AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY
//
// * Access Key ID: AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY
// * Secret Access Key: AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY
type EnvProvider struct {
retrieved bool
}

View file

@ -1,70 +0,0 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"os"
"testing"
)
func TestEnvProviderRetrieve(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_SESSION_TOKEN", "token")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "access", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestEnvProviderIsExpired(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_SESSION_TOKEN", "token")
e := EnvProvider{}
assert.True(t, e.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, e.IsExpired(), "Expect creds to not be expired after retrieve.")
}
func TestEnvProviderNoAccessKeyID(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Equal(t, ErrAccessKeyIDNotFound, err, "ErrAccessKeyIDNotFound expected, but was %#v error: %#v", creds, err)
}
func TestEnvProviderNoSecretAccessKey(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Equal(t, ErrSecretAccessKeyNotFound, err, "ErrSecretAccessKeyNotFound expected, but was %#v error: %#v", creds, err)
}
func TestEnvProviderAlternateNames(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY", "access")
os.Setenv("AWS_SECRET_KEY", "secret")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "access", creds.AccessKeyID, "Expected access key ID")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expected secret access key")
assert.Empty(t, creds.SessionToken, "Expected no token")
}

View file

@ -12,6 +12,8 @@ import (
var (
// ErrSharedCredentialsHomeNotFound is emitted when the user directory cannot be found.
//
// @readonly
ErrSharedCredentialsHomeNotFound = awserr.New("UserHomeNotFound", "user home directory not found.", nil)
)
@ -20,8 +22,12 @@ var (
//
// Profile ini file example: $HOME/.aws/credentials
type SharedCredentialsProvider struct {
// Path to the shared credentials file. If empty will default to current user's
// home directory.
// Path to the shared credentials file.
//
// If empty will look for "AWS_SHARED_CREDENTIALS_FILE" env variable. If the
// env value is empty will default to current user's home directory.
// Linux/OSX: "$HOME/.aws/credentials"
// Windows: "%USERPROFILE%\.aws\credentials"
Filename string
// AWS Profile to extract credentials from the shared credentials file. If empty
@ -104,6 +110,10 @@ func loadProfile(filename, profile string) (Value, error) {
// Will return an error if the user's home directory path cannot be found.
func (p *SharedCredentialsProvider) filename() (string, error) {
if p.Filename == "" {
if p.Filename = os.Getenv("AWS_SHARED_CREDENTIALS_FILE"); p.Filename != "" {
return p.Filename, nil
}
homeDir := os.Getenv("HOME") // *nix
if homeDir == "" { // Windows
homeDir = os.Getenv("USERPROFILE")

View file

@ -1,77 +0,0 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"os"
"testing"
)
func TestSharedCredentialsProvider(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestSharedCredentialsProviderIsExpired(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve")
}
func TestSharedCredentialsProviderWithAWS_PROFILE(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_PROFILE", "no_token")
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
}
func TestSharedCredentialsProviderWithoutTokenFromProfile(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: "no_token"}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
}
func BenchmarkSharedCredentialsProvider(b *testing.B) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
}
})
}

View file

@ -6,6 +6,8 @@ import (
var (
// ErrStaticCredentialsEmpty is emitted when static credentials are empty.
//
// @readonly
ErrStaticCredentialsEmpty = awserr.New("EmptyStaticCreds", "static credentials are empty", nil)
)

View file

@ -1,34 +0,0 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestStaticProviderGet(t *testing.T) {
s := StaticProvider{
Value: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
}
creds, err := s.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no session token")
}
func TestStaticProviderIsExpired(t *testing.T) {
s := StaticProvider{
Value: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
}
assert.False(t, s.IsExpired(), "Expect static credentials to never expire")
}

View file

@ -1,120 +0,0 @@
// Package stscreds are credential Providers to retrieve STS AWS credentials.
//
// STS provides multiple ways to retrieve credentials which can be used when making
// future AWS service API operation calls.
package stscreds
import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"time"
)
// AssumeRoler represents the minimal subset of the STS client API used by this provider.
type AssumeRoler interface {
AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error)
}
// AssumeRoleProvider retrieves temporary credentials from the STS service, and
// keeps track of their expiration time. This provider must be used explicitly,
// as it is not included in the credentials chain.
//
// Example how to configure a service to use this provider:
//
// config := &aws.Config{
// Credentials: stscreds.NewCredentials(nil, "arn-of-the-role-to-assume", 10*time.Second),
// })
// // Use config for creating your AWS service.
//
// Example how to obtain customised credentials:
//
// provider := &stscreds.Provider{
// // Extend the duration to 1 hour.
// Duration: time.Hour,
// // Custom role name.
// RoleSessionName: "custom-session-name",
// }
// creds := credentials.NewCredentials(provider)
//
type AssumeRoleProvider struct {
credentials.Expiry
// Custom STS client. If not set the default STS client will be used.
Client AssumeRoler
// Role to be assumed.
RoleARN string
// Session name, if you wish to reuse the credentials elsewhere.
RoleSessionName string
// Expiry duration of the STS credentials. Defaults to 15 minutes if not set.
Duration time.Duration
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
}
// NewCredentials returns a pointer to a new Credentials object wrapping the
// AssumeRoleProvider. The credentials will expire every 15 minutes and the
// role will be named after a nanosecond timestamp of this operation.
//
// The sts and roleARN parameters are used for building the "AssumeRole" call.
// Pass nil as sts to use the default client.
//
// Window is the expiry window that will be subtracted from the expiry returned
// by the role credential request. This is done so that the credentials will
// expire sooner than their actual lifespan.
func NewCredentials(client AssumeRoler, roleARN string, window time.Duration) *credentials.Credentials {
return credentials.NewCredentials(&AssumeRoleProvider{
Client: client,
RoleARN: roleARN,
ExpiryWindow: window,
})
}
// Retrieve generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
// Apply defaults where parameters are not set.
if p.Client == nil {
p.Client = sts.New(nil)
}
if p.RoleSessionName == "" {
// Try to work out a role name that will hopefully end up unique.
p.RoleSessionName = fmt.Sprintf("%d", time.Now().UTC().UnixNano())
}
if p.Duration == 0 {
// Expire as often as AWS permits.
p.Duration = 15 * time.Minute
}
roleOutput, err := p.Client.AssumeRole(&sts.AssumeRoleInput{
DurationSeconds: aws.Long(int64(p.Duration / time.Second)),
RoleARN: aws.String(p.RoleARN),
RoleSessionName: aws.String(p.RoleSessionName),
})
if err != nil {
return credentials.Value{}, err
}
// We will proactively generate new credentials before they expire.
p.SetExpiration(*roleOutput.Credentials.Expiration, p.ExpiryWindow)
return credentials.Value{
AccessKeyID: *roleOutput.Credentials.AccessKeyID,
SecretAccessKey: *roleOutput.Credentials.SecretAccessKey,
SessionToken: *roleOutput.Credentials.SessionToken,
}, nil
}

View file

@ -1,58 +0,0 @@
package stscreds
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
type stubSTS struct {
}
func (s *stubSTS) AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) {
expiry := time.Now().Add(60 * time.Minute)
return &sts.AssumeRoleOutput{
Credentials: &sts.Credentials{
// Just reflect the role arn to the provider.
AccessKeyID: input.RoleARN,
SecretAccessKey: aws.String("assumedSecretAccessKey"),
SessionToken: aws.String("assumedSessionToken"),
Expiration: &expiry,
},
}, nil
}
func TestAssumeRoleProvider(t *testing.T) {
stub := &stubSTS{}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
}
func BenchmarkAssumeRoleProvider(b *testing.B) {
stub := &stubSTS{}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
}
})
}

View file

@ -0,0 +1,39 @@
package defaults
import (
"net/http"
"os"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
)
// DefaultChainCredentials is a Credentials which will find the first available
// credentials Value from the list of Providers.
//
// This should be used in the default case. Once the type of credentials are
// known switching to the specific Credentials will be more efficient.
var DefaultChainCredentials = credentials.NewChainCredentials(
[]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
&ec2rolecreds.EC2RoleProvider{ExpiryWindow: 5 * time.Minute},
})
// DefaultConfig is the default all service configuration will be based off of.
// By default, all clients use this structure for initialization options unless
// a custom configuration object is passed in.
//
// You may modify this global structure to change all default configuration
// in the SDK. Note that configuration options are copied by value, so any
// modifications must happen before constructing a client.
var DefaultConfig = aws.NewConfig().
WithCredentials(DefaultChainCredentials).
WithRegion(os.Getenv("AWS_REGION")).
WithHTTPClient(http.DefaultClient).
WithMaxRetries(aws.DefaultRetries).
WithLogger(aws.NewDefaultLogger()).
WithLogLevel(aws.LogOff).
WithSleepDelay(time.Sleep)

View file

@ -0,0 +1,43 @@
package ec2metadata
import (
"path"
"github.com/aws/aws-sdk-go/aws/request"
)
// GetMetadata uses the path provided to request
func (c *Client) GetMetadata(p string) (string, error) {
op := &request.Operation{
Name: "GetMetadata",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "meta-data", p),
}
output := &metadataOutput{}
req := request.New(c.Service.ServiceInfo, c.Service.Handlers, c.Service.Retryer, op, nil, output)
return output.Content, req.Send()
}
// Region returns the region the instance is running in.
func (c *Client) Region() (string, error) {
resp, err := c.GetMetadata("placement/availability-zone")
if err != nil {
return "", err
}
// returns region without the suffix. Eg: us-west-2a becomes us-west-2
return resp[:len(resp)-1], nil
}
// Available returns if the application has access to the EC2 Metadata service.
// Can be used to determine if application is running within an EC2 Instance and
// the metadata service is available.
func (c *Client) Available() bool {
if _, err := c.GetMetadata("instance-id"); err != nil {
return false
}
return true
}

View file

@ -0,0 +1,135 @@
package ec2metadata
import (
"io/ioutil"
"net/http"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/service"
"github.com/aws/aws-sdk-go/aws/service/serviceinfo"
)
// DefaultRetries states the default number of times the service client will
// attempt to retry a failed request before failing.
const DefaultRetries = 3
// A Config provides the configuration for the EC2 Metadata service.
type Config struct {
// An optional endpoint URL (hostname only or fully qualified URI)
// that overrides the default service endpoint for a client. Set this
// to nil, or `""` to use the default service endpoint.
Endpoint *string
// The HTTP client to use when sending requests. Defaults to
// `http.DefaultClient`.
HTTPClient *http.Client
// An integer value representing the logging level. The default log level
// is zero (LogOff), which represents no logging. To enable logging set
// to a LogLevel Value.
Logger aws.Logger
// The logger writer interface to write logging messages to. Defaults to
// standard out.
LogLevel *aws.LogLevelType
// The maximum number of times that a request will be retried for failures.
// Defaults to DefaultRetries for the number of retries to be performed
// per request.
MaxRetries *int
}
// A Client is an EC2 Metadata service Client.
type Client struct {
*service.Service
}
// New creates a new instance of the EC2 Metadata service client.
//
// In the general use case the configuration for this service client should not
// be needed and `nil` can be provided. Configuration is only needed if the
// `ec2metadata.Config` defaults need to be overridden. Eg. Setting LogLevel.
//
// @note This configuration will NOT be merged with the default AWS service
// client configuration `defaults.DefaultConfig`. Due to circular dependencies
// with the defaults package and credentials EC2 Role Provider.
func New(config *Config) *Client {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: copyConfig(config),
ServiceName: "Client",
Endpoint: "http://169.254.169.254/latest",
APIVersion: "latest",
},
}
service.Initialize()
service.Handlers.Unmarshal.PushBack(unmarshalHandler)
service.Handlers.UnmarshalError.PushBack(unmarshalError)
service.Handlers.Validate.Clear()
service.Handlers.Validate.PushBack(validateEndpointHandler)
return &Client{service}
}
func copyConfig(config *Config) *aws.Config {
if config == nil {
config = &Config{}
}
c := &aws.Config{
Credentials: credentials.AnonymousCredentials,
Endpoint: config.Endpoint,
HTTPClient: config.HTTPClient,
Logger: config.Logger,
LogLevel: config.LogLevel,
MaxRetries: config.MaxRetries,
}
if c.HTTPClient == nil {
c.HTTPClient = http.DefaultClient
}
if c.Logger == nil {
c.Logger = aws.NewDefaultLogger()
}
if c.LogLevel == nil {
c.LogLevel = aws.LogLevel(aws.LogOff)
}
if c.MaxRetries == nil {
c.MaxRetries = aws.Int(DefaultRetries)
}
return c
}
type metadataOutput struct {
Content string
}
func unmarshalHandler(r *request.Request) {
defer r.HTTPResponse.Body.Close()
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata respose", err)
}
data := r.Data.(*metadataOutput)
data.Content = string(b)
}
func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
_, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata error respose", err)
}
// TODO extract the error...
}
func validateEndpointHandler(r *request.Request) {
if r.Service.Endpoint == "" {
r.Error = aws.ErrMissingEndpoint
}
}

17
vendor/github.com/aws/aws-sdk-go/aws/errors.go generated vendored Normal file
View file

@ -0,0 +1,17 @@
package aws
import "github.com/aws/aws-sdk-go/aws/awserr"
var (
// ErrMissingRegion is an error that is returned if region configuration is
// not found.
//
// @readonly
ErrMissingRegion error = awserr.New("MissingRegion", "could not find region configuration", nil)
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be
// resolved for a service.
//
// @readonly
ErrMissingEndpoint error = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
)

View file

@ -1,81 +0,0 @@
package aws
import (
"net/http"
"os"
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
)
func TestValidateEndpointHandler(t *testing.T) {
os.Clearenv()
svc := NewService(&Config{Region: "us-west-2"})
svc.Handlers.Clear()
svc.Handlers.Validate.PushBack(ValidateEndpointHandler)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.NoError(t, err)
}
func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
os.Clearenv()
svc := NewService(nil)
svc.Handlers.Clear()
svc.Handlers.Validate.PushBack(ValidateEndpointHandler)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.Error(t, err)
assert.Equal(t, ErrMissingRegion, err)
}
type mockCredsProvider struct {
expired bool
retreiveCalled bool
}
func (m *mockCredsProvider) Retrieve() (credentials.Value, error) {
m.retreiveCalled = true
return credentials.Value{}, nil
}
func (m *mockCredsProvider) IsExpired() bool {
return m.expired
}
func TestAfterRetryRefreshCreds(t *testing.T) {
os.Clearenv()
credProvider := &mockCredsProvider{}
svc := NewService(&Config{Credentials: credentials.NewCredentials(credProvider), MaxRetries: 1})
svc.Handlers.Clear()
svc.Handlers.ValidateResponse.PushBack(func(r *Request) {
r.Error = awserr.New("UnknownError", "", nil)
r.HTTPResponse = &http.Response{StatusCode: 400}
})
svc.Handlers.UnmarshalError.PushBack(func(r *Request) {
r.Error = awserr.New("ExpiredTokenException", "", nil)
})
svc.Handlers.AfterRetry.PushBack(func(r *Request) {
AfterRetryHandler(r)
})
assert.True(t, svc.Config.Credentials.IsExpired(), "Expect to start out expired")
assert.False(t, credProvider.retreiveCalled)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
req.Send()
assert.True(t, svc.Config.Credentials.IsExpired())
assert.False(t, credProvider.retreiveCalled)
_, err := svc.Config.Credentials.Get()
assert.NoError(t, err)
assert.True(t, credProvider.retreiveCalled)
}

View file

@ -1,31 +0,0 @@
package aws
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestHandlerList(t *testing.T) {
s := ""
r := &Request{}
l := HandlerList{}
l.PushBack(func(r *Request) {
s += "a"
r.Data = s
})
l.Run(r)
assert.Equal(t, "a", s)
assert.Equal(t, "a", r.Data)
}
func TestMultipleHandlers(t *testing.T) {
r := &Request{}
l := HandlerList{}
l.PushBack(func(r *Request) { r.Data = nil })
l.PushFront(func(r *Request) { r.Data = Boolean(true) })
l.Run(r)
if r.Data != nil {
t.Error("Expected handler to execute")
}
}

98
vendor/github.com/aws/aws-sdk-go/aws/logger.go generated vendored Normal file
View file

@ -0,0 +1,98 @@
package aws
import (
"log"
"os"
)
// A LogLevelType defines the level logging should be performed at. Used to instruct
// the SDK which statements should be logged.
type LogLevelType uint
// LogLevel returns the pointer to a LogLevel. Should be used to workaround
// not being able to take the address of a non-composite literal.
func LogLevel(l LogLevelType) *LogLevelType {
return &l
}
// Value returns the LogLevel value or the default value LogOff if the LogLevel
// is nil. Safe to use on nil value LogLevelTypes.
func (l *LogLevelType) Value() LogLevelType {
if l != nil {
return *l
}
return LogOff
}
// Matches returns true if the v LogLevel is enabled by this LogLevel. Should be
// used with logging sub levels. Is safe to use on nil value LogLevelTypes. If
// LogLevel is nill, will default to LogOff comparison.
func (l *LogLevelType) Matches(v LogLevelType) bool {
c := l.Value()
return c&v == v
}
// AtLeast returns true if this LogLevel is at least high enough to satisfies v.
// Is safe to use on nil value LogLevelTypes. If LogLevel is nill, will default
// to LogOff comparison.
func (l *LogLevelType) AtLeast(v LogLevelType) bool {
c := l.Value()
return c >= v
}
const (
// LogOff states that no logging should be performed by the SDK. This is the
// default state of the SDK, and should be use to disable all logging.
LogOff LogLevelType = iota * 0x1000
// LogDebug state that debug output should be logged by the SDK. This should
// be used to inspect request made and responses received.
LogDebug
)
// Debug Logging Sub Levels
const (
// LogDebugWithSigning states that the SDK should log request signing and
// presigning events. This should be used to log the signing details of
// requests for debugging. Will also enable LogDebug.
LogDebugWithSigning LogLevelType = LogDebug | (1 << iota)
// LogDebugWithHTTPBody states the SDK should log HTTP request and response
// HTTP bodys in addition to the headers and path. This should be used to
// see the body content of requests and responses made while using the SDK
// Will also enable LogDebug.
LogDebugWithHTTPBody
// LogDebugWithRequestRetries states the SDK should log when service requests will
// be retried. This should be used to log when you want to log when service
// requests are being retried. Will also enable LogDebug.
LogDebugWithRequestRetries
// LogDebugWithRequestErrors states the SDK should log when service requests fail
// to build, send, validate, or unmarshal.
LogDebugWithRequestErrors
)
// A Logger is a minimalistic interface for the SDK to log messages to. Should
// be used to provide custom logging writers for the SDK to use.
type Logger interface {
Log(...interface{})
}
// NewDefaultLogger returns a Logger which will write log messages to stdout, and
// use same formatting runes as the stdlib log.Logger
func NewDefaultLogger() Logger {
return &defaultLogger{
logger: log.New(os.Stdout, "", log.LstdFlags),
}
}
// A defaultLogger provides a minimalistic logger satisfying the Logger interface.
type defaultLogger struct {
logger *log.Logger
}
// Log logs the parameters to the stdlib logger. See log.Println.
func (l defaultLogger) Log(args ...interface{}) {
l.logger.Println(args...)
}

View file

@ -1,89 +0,0 @@
package aws
import (
"fmt"
"reflect"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// ValidateParameters is a request handler to validate the input parameters.
// Validating parameters only has meaning if done prior to the request being sent.
func ValidateParameters(r *Request) {
if r.ParamsFilled() {
v := validator{errors: []string{}}
v.validateAny(reflect.ValueOf(r.Params), "")
if count := len(v.errors); count > 0 {
format := "%d validation errors:\n- %s"
msg := fmt.Sprintf(format, count, strings.Join(v.errors, "\n- "))
r.Error = awserr.New("InvalidParameter", msg, nil)
}
}
}
// A validator validates values. Collects validations errors which occurs.
type validator struct {
errors []string
}
// validateAny will validate any struct, slice or map type. All validations
// are also performed recursively for nested types.
func (v *validator) validateAny(value reflect.Value, path string) {
value = reflect.Indirect(value)
if !value.IsValid() {
return
}
switch value.Kind() {
case reflect.Struct:
v.validateStruct(value, path)
case reflect.Slice:
for i := 0; i < value.Len(); i++ {
v.validateAny(value.Index(i), path+fmt.Sprintf("[%d]", i))
}
case reflect.Map:
for _, n := range value.MapKeys() {
v.validateAny(value.MapIndex(n), path+fmt.Sprintf("[%q]", n.String()))
}
}
}
// validateStruct will validate the struct value's fields. If the structure has
// nested types those types will be validated also.
func (v *validator) validateStruct(value reflect.Value, path string) {
prefix := "."
if path == "" {
prefix = ""
}
for i := 0; i < value.Type().NumField(); i++ {
f := value.Type().Field(i)
if strings.ToLower(f.Name[0:1]) == f.Name[0:1] {
continue
}
fvalue := value.FieldByName(f.Name)
notset := false
if f.Tag.Get("required") != "" {
switch fvalue.Kind() {
case reflect.Ptr, reflect.Slice, reflect.Map:
if fvalue.IsNil() {
notset = true
}
default:
if !fvalue.IsValid() {
notset = true
}
}
}
if notset {
msg := "missing required parameter: " + path + prefix + f.Name
v.errors = append(v.errors, msg)
} else {
v.validateAny(fvalue, path+prefix+f.Name)
}
}
}

View file

@ -1,84 +0,0 @@
package aws_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
var service = func() *aws.Service {
s := &aws.Service{
Config: &aws.Config{},
ServiceName: "mock-service",
APIVersion: "2015-01-01",
}
return s
}()
type StructShape struct {
RequiredList []*ConditionalStructShape `required:"true"`
RequiredMap map[string]*ConditionalStructShape `required:"true"`
RequiredBool *bool `required:"true"`
OptionalStruct *ConditionalStructShape
hiddenParameter *string
metadataStructureShape
}
type metadataStructureShape struct {
SDKShapeTraits bool
}
type ConditionalStructShape struct {
Name *string `required:"true"`
SDKShapeTraits bool
}
func TestNoErrors(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {Name: aws.String("Name")},
},
RequiredBool: aws.Boolean(true),
OptionalStruct: &ConditionalStructShape{Name: aws.String("Name")},
}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
assert.NoError(t, req.Error)
}
func TestMissingRequiredParameters(t *testing.T) {
input := &StructShape{}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
assert.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList\n- missing required parameter: RequiredMap\n- missing required parameter: RequiredBool", req.Error.(awserr.Error).Message())
}
func TestNestedMissingRequiredParameters(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{{}},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {},
},
RequiredBool: aws.Boolean(true),
OptionalStruct: &ConditionalStructShape{},
}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
assert.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList[0].Name\n- missing required parameter: RequiredMap[\"key2\"].Name\n- missing required parameter: OptionalStruct.Name", req.Error.(awserr.Error).Message())
}

View file

@ -1,4 +1,4 @@
package aws
package request
// A Handlers provides a collection of request handlers for various
// stages of handling requests.
@ -15,8 +15,8 @@ type Handlers struct {
AfterRetry HandlerList
}
// copy returns of this handler's lists.
func (h *Handlers) copy() Handlers {
// Copy returns of this handler's lists.
func (h *Handlers) Copy() Handlers {
return Handlers{
Validate: h.Validate.copy(),
Build: h.Build.copy(),
@ -47,19 +47,25 @@ func (h *Handlers) Clear() {
// A HandlerList manages zero or more handlers in a list.
type HandlerList struct {
list []func(*Request)
list []NamedHandler
}
// A NamedHandler is a struct that contains a name and function callback.
type NamedHandler struct {
Name string
Fn func(*Request)
}
// copy creates a copy of the handler list.
func (l *HandlerList) copy() HandlerList {
var n HandlerList
n.list = append([]func(*Request){}, l.list...)
n.list = append([]NamedHandler{}, l.list...)
return n
}
// Clear clears the handler list.
func (l *HandlerList) Clear() {
l.list = []func(*Request){}
l.list = []NamedHandler{}
}
// Len returns the number of handlers in the list.
@ -67,19 +73,40 @@ func (l *HandlerList) Len() int {
return len(l.list)
}
// PushBack pushes handlers f to the back of the handler list.
func (l *HandlerList) PushBack(f ...func(*Request)) {
l.list = append(l.list, f...)
// PushBack pushes handler f to the back of the handler list.
func (l *HandlerList) PushBack(f func(*Request)) {
l.list = append(l.list, NamedHandler{"__anonymous", f})
}
// PushFront pushes handlers f to the front of the handler list.
func (l *HandlerList) PushFront(f ...func(*Request)) {
l.list = append(f, l.list...)
// PushFront pushes handler f to the front of the handler list.
func (l *HandlerList) PushFront(f func(*Request)) {
l.list = append([]NamedHandler{{"__anonymous", f}}, l.list...)
}
// PushBackNamed pushes named handler f to the back of the handler list.
func (l *HandlerList) PushBackNamed(n NamedHandler) {
l.list = append(l.list, n)
}
// PushFrontNamed pushes named handler f to the front of the handler list.
func (l *HandlerList) PushFrontNamed(n NamedHandler) {
l.list = append([]NamedHandler{n}, l.list...)
}
// Remove removes a NamedHandler n
func (l *HandlerList) Remove(n NamedHandler) {
newlist := []NamedHandler{}
for _, m := range l.list {
if m.Name != n.Name {
newlist = append(newlist, m)
}
}
l.list = newlist
}
// Run executes all handlers in the list with a given request object.
func (l *HandlerList) Run(r *Request) {
for _, f := range l.list {
f(r)
f.Fn(r)
}
}

View file

@ -1,7 +1,8 @@
package aws
package request
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
@ -10,12 +11,15 @@ import (
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/service/serviceinfo"
)
// A Request is the service request to be made.
type Request struct {
*Service
Retryer
Service serviceinfo.ServiceInfo
Handlers Handlers
Time time.Time
ExpireTime time.Duration
@ -23,13 +27,13 @@ type Request struct {
HTTPRequest *http.Request
HTTPResponse *http.Response
Body io.ReadSeeker
bodyStart int64 // offset from beginning of Body that the request body starts
BodyStart int64 // offset from beginning of Body that the request body starts
Params interface{}
Error error
Data interface{}
RequestID string
RetryCount uint
Retryable SettableBool
Retryable *bool
RetryDelay time.Duration
built bool
@ -51,13 +55,13 @@ type Paginator struct {
TruncationToken string
}
// NewRequest returns a new Request pointer for the service API
// New returns a new Request pointer for the service API
// operation and parameters.
//
// Params is any value of input parameters to be the request payload.
// Data is pointer value to an object which the request's response
// payload will be deserialized to.
func NewRequest(service *Service, operation *Operation, params interface{}, data interface{}) *Request {
func New(service serviceinfo.ServiceInfo, handlers Handlers, retryer Retryer, operation *Operation, params interface{}, data interface{}) *Request {
method := operation.HTTPMethod
if method == "" {
method = "POST"
@ -71,8 +75,9 @@ func NewRequest(service *Service, operation *Operation, params interface{}, data
httpReq.URL, _ = url.Parse(service.Endpoint + p)
r := &Request{
Retryer: retryer,
Service: service,
Handlers: service.Handlers.copy(),
Handlers: handlers.Copy(),
Time: time.Now(),
ExpireTime: 0,
Operation: operation,
@ -89,7 +94,7 @@ func NewRequest(service *Service, operation *Operation, params interface{}, data
// WillRetry returns if the request's can be retried.
func (r *Request) WillRetry() bool {
return r.Error != nil && r.Retryable.Get() && r.RetryCount < r.Service.MaxRetries()
return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries()
}
// ParamsFilled returns if the request's parameters have been populated
@ -134,6 +139,20 @@ func (r *Request) Presign(expireTime time.Duration) (string, error) {
return r.HTTPRequest.URL.String(), nil
}
func debugLogReqError(r *Request, stage string, retrying bool, err error) {
if !r.Service.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) {
return
}
retryStr := "not retrying"
if retrying {
retryStr = "will retry"
}
r.Service.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v",
stage, r.Service.ServiceName, r.Operation.Name, retryStr, err))
}
// Build will build the request's object so it can be signed and sent
// to the service. Build will also validate all the request's parameters.
// Anny additional build Handlers set on this request will be run
@ -149,6 +168,7 @@ func (r *Request) Build() error {
r.Error = nil
r.Handlers.Validate.Run(r)
if r.Error != nil {
debugLogReqError(r, "Validate Request", false, r.Error)
return r.Error
}
r.Handlers.Build.Run(r)
@ -165,6 +185,7 @@ func (r *Request) Build() error {
func (r *Request) Sign() error {
r.Build()
if r.Error != nil {
debugLogReqError(r, "Build Request", false, r.Error)
return r.Error
}
@ -183,42 +204,57 @@ func (r *Request) Send() error {
return r.Error
}
if r.Retryable.Get() {
if aws.BoolValue(r.Retryable) {
if r.Service.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) {
r.Service.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d",
r.Service.ServiceName, r.Operation.Name, r.RetryCount))
}
// Re-seek the body back to the original point in for a retry so that
// send will send the body's contents again in the upcoming request.
r.Body.Seek(r.bodyStart, 0)
r.Body.Seek(r.BodyStart, 0)
r.HTTPRequest.Body = ioutil.NopCloser(r.Body)
}
r.Retryable.Reset()
r.Retryable = nil
r.Handlers.Send.Run(r)
if r.Error != nil {
err := r.Error
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
debugLogReqError(r, "Send Request", false, r.Error)
return r.Error
}
debugLogReqError(r, "Send Request", true, err)
continue
}
r.Handlers.UnmarshalMeta.Run(r)
r.Handlers.ValidateResponse.Run(r)
if r.Error != nil {
err := r.Error
r.Handlers.UnmarshalError.Run(r)
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
debugLogReqError(r, "Validate Response", false, r.Error)
return r.Error
}
debugLogReqError(r, "Validate Response", true, err)
continue
}
r.Handlers.Unmarshal.Run(r)
if r.Error != nil {
err := r.Error
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
debugLogReqError(r, "Unmarshal Response", false, r.Error)
return r.Error
}
debugLogReqError(r, "Unmarshal Response", true, err)
continue
}
@ -279,7 +315,7 @@ func (r *Request) NextPage() *Request {
}
data := reflect.New(reflect.TypeOf(r.Data).Elem()).Interface()
nr := NewRequest(r.Service, r.Operation, awsutil.CopyOf(r.Params), data)
nr := New(r.Service, r.Handlers, r.Retryer, r.Operation, awsutil.CopyOf(r.Params), data)
for i, intok := range nr.Operation.InputTokens {
awsutil.SetValueAtAnyPath(nr.Params, intok, tokens[i])
}

View file

@ -0,0 +1,71 @@
package request
import (
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// Retryer is an interface to control retry logic for a given service.
// The default implementation used by most services is the service.DefaultRetryer
// structure, which contains basic retry logic using exponential backoff.
type Retryer interface {
RetryRules(*Request) time.Duration
ShouldRetry(*Request) bool
MaxRetries() uint
}
// retryableCodes is a collection of service response codes which are retry-able
// without any further action.
var retryableCodes = map[string]struct{}{
"RequestError": {},
"ProvisionedThroughputExceededException": {},
"Throttling": {},
"ThrottlingException": {},
"RequestLimitExceeded": {},
"RequestThrottled": {},
}
// credsExpiredCodes is a collection of error codes which signify the credentials
// need to be refreshed. Expired tokens require refreshing of credentials, and
// resigning before the request can be retried.
var credsExpiredCodes = map[string]struct{}{
"ExpiredToken": {},
"ExpiredTokenException": {},
"RequestExpired": {}, // EC2 Only
}
func isCodeRetryable(code string) bool {
if _, ok := retryableCodes[code]; ok {
return true
}
return isCodeExpiredCreds(code)
}
func isCodeExpiredCreds(code string) bool {
_, ok := credsExpiredCodes[code]
return ok
}
// IsErrorRetryable returns whether the error is retryable, based on its Code.
// Returns false if the request has no Error set.
func (r *Request) IsErrorRetryable() bool {
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
return isCodeRetryable(err.Code())
}
}
return false
}
// IsErrorExpired returns whether the error code is a credential expiry error.
// Returns false if the request has no Error set.
func (r *Request) IsErrorExpired() bool {
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
return isCodeExpiredCreds(err.Code())
}
}
return false
}

View file

@ -1,305 +0,0 @@
package aws_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
// Use DynamoDB methods for simplicity
func TestPagination(t *testing.T) {
db := dynamodb.New(nil)
tokens, pages, numPages, gotToEnd := []string{}, []string{}, 0, false
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Build.PushBack(func(r *aws.Request) {
in := r.Params.(*dynamodb.ListTablesInput)
if in == nil {
tokens = append(tokens, "")
} else if in.ExclusiveStartTableName != nil {
tokens = append(tokens, *in.ExclusiveStartTableName)
}
})
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.ListTablesInput{Limit: aws.Long(2)}
err := db.ListTablesPages(params, func(p *dynamodb.ListTablesOutput, last bool) bool {
numPages++
for _, t := range p.TableNames {
pages = append(pages, *t)
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Equal(t, []string{"Table2", "Table4"}, tokens)
assert.Equal(t, []string{"Table1", "Table2", "Table3", "Table4", "Table5"}, pages)
assert.Equal(t, 3, numPages)
assert.True(t, gotToEnd)
assert.Nil(t, err)
assert.Nil(t, params.ExclusiveStartTableName)
}
// Use DynamoDB methods for simplicity
func TestPaginationEachPage(t *testing.T) {
db := dynamodb.New(nil)
tokens, pages, numPages, gotToEnd := []string{}, []string{}, 0, false
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Build.PushBack(func(r *aws.Request) {
in := r.Params.(*dynamodb.ListTablesInput)
if in == nil {
tokens = append(tokens, "")
} else if in.ExclusiveStartTableName != nil {
tokens = append(tokens, *in.ExclusiveStartTableName)
}
})
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.ListTablesInput{Limit: aws.Long(2)}
req, _ := db.ListTablesRequest(params)
err := req.EachPage(func(p interface{}, last bool) bool {
numPages++
for _, t := range p.(*dynamodb.ListTablesOutput).TableNames {
pages = append(pages, *t)
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Equal(t, []string{"Table2", "Table4"}, tokens)
assert.Equal(t, []string{"Table1", "Table2", "Table3", "Table4", "Table5"}, pages)
assert.Equal(t, 3, numPages)
assert.True(t, gotToEnd)
assert.Nil(t, err)
}
// Use DynamoDB methods for simplicity
func TestPaginationEarlyExit(t *testing.T) {
db := dynamodb.New(nil)
numPages, gotToEnd := 0, false
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.ListTablesInput{Limit: aws.Long(2)}
err := db.ListTablesPages(params, func(p *dynamodb.ListTablesOutput, last bool) bool {
numPages++
if numPages == 2 {
return false
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Equal(t, 2, numPages)
assert.False(t, gotToEnd)
assert.Nil(t, err)
}
func TestSkipPagination(t *testing.T) {
client := s3.New(nil)
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = &s3.HeadBucketOutput{}
})
req, _ := client.HeadBucketRequest(&s3.HeadBucketInput{Bucket: aws.String("bucket")})
numPages, gotToEnd := 0, false
req.EachPage(func(p interface{}, last bool) bool {
numPages++
if last {
gotToEnd = true
}
return true
})
assert.Equal(t, 1, numPages)
assert.True(t, gotToEnd)
}
// Use S3 for simplicity
func TestPaginationTruncation(t *testing.T) {
count := 0
client := s3.New(nil)
reqNum := &count
resps := []*s3.ListObjectsOutput{
{IsTruncated: aws.Boolean(true), Contents: []*s3.Object{{Key: aws.String("Key1")}}},
{IsTruncated: aws.Boolean(true), Contents: []*s3.Object{{Key: aws.String("Key2")}}},
{IsTruncated: aws.Boolean(false), Contents: []*s3.Object{{Key: aws.String("Key3")}}},
{IsTruncated: aws.Boolean(true), Contents: []*s3.Object{{Key: aws.String("Key4")}}},
}
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = resps[*reqNum]
*reqNum++
})
params := &s3.ListObjectsInput{Bucket: aws.String("bucket")}
results := []string{}
err := client.ListObjectsPages(params, func(p *s3.ListObjectsOutput, last bool) bool {
results = append(results, *p.Contents[0].Key)
return true
})
assert.Equal(t, []string{"Key1", "Key2", "Key3"}, results)
assert.Nil(t, err)
// Try again without truncation token at all
count = 0
resps[1].IsTruncated = nil
resps[2].IsTruncated = aws.Boolean(true)
results = []string{}
err = client.ListObjectsPages(params, func(p *s3.ListObjectsOutput, last bool) bool {
results = append(results, *p.Contents[0].Key)
return true
})
assert.Equal(t, []string{"Key1", "Key2"}, results)
assert.Nil(t, err)
}
// Benchmarks
var benchResps = []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE")}},
}
var benchDb = func() *dynamodb.DynamoDB {
db := dynamodb.New(nil)
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
return db
}
func BenchmarkCodegenIterator(b *testing.B) {
reqNum := 0
db := benchDb()
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = benchResps[reqNum]
reqNum++
})
input := &dynamodb.ListTablesInput{Limit: aws.Long(2)}
iter := func(fn func(*dynamodb.ListTablesOutput, bool) bool) error {
page, _ := db.ListTablesRequest(input)
for ; page != nil; page = page.NextPage() {
page.Send()
out := page.Data.(*dynamodb.ListTablesOutput)
if result := fn(out, !page.HasNextPage()); page.Error != nil || !result {
return page.Error
}
}
return nil
}
for i := 0; i < b.N; i++ {
reqNum = 0
iter(func(p *dynamodb.ListTablesOutput, last bool) bool {
return true
})
}
}
func BenchmarkEachPageIterator(b *testing.B) {
reqNum := 0
db := benchDb()
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = benchResps[reqNum]
reqNum++
})
input := &dynamodb.ListTablesInput{Limit: aws.Long(2)}
for i := 0; i < b.N; i++ {
reqNum = 0
req, _ := db.ListTablesRequest(input)
req.EachPage(func(p interface{}, last bool) bool {
return true
})
}
}

View file

@ -1,219 +0,0 @@
package aws
import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"reflect"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
)
type testData struct {
Data string
}
func body(str string) io.ReadCloser {
return ioutil.NopCloser(bytes.NewReader([]byte(str)))
}
func unmarshal(req *Request) {
defer req.HTTPResponse.Body.Close()
if req.Data != nil {
json.NewDecoder(req.HTTPResponse.Body).Decode(req.Data)
}
return
}
func unmarshalError(req *Request) {
bodyBytes, err := ioutil.ReadAll(req.HTTPResponse.Body)
if err != nil {
req.Error = awserr.New("UnmarshaleError", req.HTTPResponse.Status, err)
return
}
if len(bodyBytes) == 0 {
req.Error = awserr.NewRequestFailure(
awserr.New("UnmarshaleError", req.HTTPResponse.Status, fmt.Errorf("empty body")),
req.HTTPResponse.StatusCode,
"",
)
return
}
var jsonErr jsonErrorResponse
if err := json.Unmarshal(bodyBytes, &jsonErr); err != nil {
req.Error = awserr.New("UnmarshaleError", "JSON unmarshal", err)
return
}
req.Error = awserr.NewRequestFailure(
awserr.New(jsonErr.Code, jsonErr.Message, nil),
req.HTTPResponse.StatusCode,
"",
)
}
type jsonErrorResponse struct {
Code string `json:"__type"`
Message string `json:"message"`
}
// test that retries occur for 5xx status codes
func TestRequestRecoverRetry5xx(t *testing.T) {
reqNum := 0
reqs := []http.Response{
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 501, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(&Config{MaxRetries: 10})
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)
assert.Equal(t, 2, int(r.RetryCount))
assert.Equal(t, "valid", out.Data)
}
// test that retries occur for 4xx status codes with a response type that can be retried - see `shouldRetry`
func TestRequestRecoverRetry4xxRetryable(t *testing.T) {
reqNum := 0
reqs := []http.Response{
{StatusCode: 400, Body: body(`{"__type":"Throttling","message":"Rate exceeded."}`)},
{StatusCode: 429, Body: body(`{"__type":"ProvisionedThroughputExceededException","message":"Rate exceeded."}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(&Config{MaxRetries: 10})
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)
assert.Equal(t, 2, int(r.RetryCount))
assert.Equal(t, "valid", out.Data)
}
// test that retries don't occur for 4xx status codes with a response type that can't be retried
func TestRequest4xxUnretryable(t *testing.T) {
s := NewService(&Config{MaxRetries: 10})
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &http.Response{StatusCode: 401, Body: body(`{"__type":"SignatureDoesNotMatch","message":"Signature does not match."}`)}
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.NotNil(t, err)
if e, ok := err.(awserr.RequestFailure); ok {
assert.Equal(t, 401, e.StatusCode())
} else {
assert.Fail(t, "Expected error to be a service failure")
}
assert.Equal(t, "SignatureDoesNotMatch", err.(awserr.Error).Code())
assert.Equal(t, "Signature does not match.", err.(awserr.Error).Message())
assert.Equal(t, 0, int(r.RetryCount))
}
func TestRequestExhaustRetries(t *testing.T) {
delays := []time.Duration{}
sleepDelay = func(delay time.Duration) {
delays = append(delays, delay)
}
reqNum := 0
reqs := []http.Response{
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
}
s := NewService(&Config{MaxRetries: -1})
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
r := NewRequest(s, &Operation{Name: "Operation"}, nil, nil)
err := r.Send()
assert.NotNil(t, err)
if e, ok := err.(awserr.RequestFailure); ok {
assert.Equal(t, 500, e.StatusCode())
} else {
assert.Fail(t, "Expected error to be a service failure")
}
assert.Equal(t, "UnknownError", err.(awserr.Error).Code())
assert.Equal(t, "An error occurred.", err.(awserr.Error).Message())
assert.Equal(t, 3, int(r.RetryCount))
assert.True(t, reflect.DeepEqual([]time.Duration{30 * time.Millisecond, 60 * time.Millisecond, 120 * time.Millisecond}, delays))
}
// test that the request is retried after the credentials are expired.
func TestRequestRecoverExpiredCreds(t *testing.T) {
reqNum := 0
reqs := []http.Response{
{StatusCode: 400, Body: body(`{"__type":"ExpiredTokenException","message":"expired token"}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(&Config{MaxRetries: 10, Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "")})
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
credExpiredBeforeRetry := false
credExpiredAfterRetry := false
s.Handlers.AfterRetry.PushBack(func(r *Request) {
credExpiredAfterRetry = r.Config.Credentials.IsExpired()
})
s.Handlers.Sign.Clear()
s.Handlers.Sign.PushBack(func(r *Request) {
r.Config.Credentials.Get()
})
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)
assert.False(t, credExpiredBeforeRetry, "Expect valid creds before retry check")
assert.True(t, credExpiredAfterRetry, "Expect expired creds after retry check")
assert.False(t, s.Config.Credentials.IsExpired(), "Expect valid creds after cred expired recovery")
assert.Equal(t, 1, int(r.RetryCount))
assert.Equal(t, "valid", out.Data)
}

View file

@ -1,177 +0,0 @@
package aws
import (
"fmt"
"math"
"net/http"
"net/http/httputil"
"regexp"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/endpoints"
)
// A Service implements the base service request and response handling
// used by all services.
type Service struct {
Config *Config
Handlers Handlers
ManualSend bool
ServiceName string
APIVersion string
Endpoint string
SigningName string
SigningRegion string
JSONVersion string
TargetPrefix string
RetryRules func(*Request) time.Duration
ShouldRetry func(*Request) bool
DefaultMaxRetries uint
}
var schemeRE = regexp.MustCompile("^([^:]+)://")
// NewService will return a pointer to a new Server object initialized.
func NewService(config *Config) *Service {
svc := &Service{Config: config}
svc.Initialize()
return svc
}
// Initialize initializes the service.
func (s *Service) Initialize() {
if s.Config == nil {
s.Config = &Config{}
}
if s.Config.HTTPClient == nil {
s.Config.HTTPClient = http.DefaultClient
}
if s.RetryRules == nil {
s.RetryRules = retryRules
}
if s.ShouldRetry == nil {
s.ShouldRetry = shouldRetry
}
s.DefaultMaxRetries = 3
s.Handlers.Validate.PushBack(ValidateEndpointHandler)
s.Handlers.Build.PushBack(UserAgentHandler)
s.Handlers.Sign.PushBack(BuildContentLength)
s.Handlers.Send.PushBack(SendHandler)
s.Handlers.AfterRetry.PushBack(AfterRetryHandler)
s.Handlers.ValidateResponse.PushBack(ValidateResponseHandler)
s.AddDebugHandlers()
s.buildEndpoint()
if !s.Config.DisableParamValidation {
s.Handlers.Validate.PushBack(ValidateParameters)
}
}
// buildEndpoint builds the endpoint values the service will use to make requests with.
func (s *Service) buildEndpoint() {
if s.Config.Endpoint != "" {
s.Endpoint = s.Config.Endpoint
} else {
s.Endpoint, s.SigningRegion =
endpoints.EndpointForRegion(s.ServiceName, s.Config.Region)
}
if s.Endpoint != "" && !schemeRE.MatchString(s.Endpoint) {
scheme := "https"
if s.Config.DisableSSL {
scheme = "http"
}
s.Endpoint = scheme + "://" + s.Endpoint
}
}
// AddDebugHandlers injects debug logging handlers into the service to log request
// debug information.
func (s *Service) AddDebugHandlers() {
out := s.Config.Logger
if s.Config.LogLevel == 0 {
return
}
s.Handlers.Send.PushFront(func(r *Request) {
logBody := r.Config.LogHTTPBody
dumpedBody, _ := httputil.DumpRequestOut(r.HTTPRequest, logBody)
fmt.Fprintf(out, "---[ REQUEST POST-SIGN ]-----------------------------\n")
fmt.Fprintf(out, "%s\n", string(dumpedBody))
fmt.Fprintf(out, "-----------------------------------------------------\n")
})
s.Handlers.Send.PushBack(func(r *Request) {
fmt.Fprintf(out, "---[ RESPONSE ]--------------------------------------\n")
if r.HTTPResponse != nil {
logBody := r.Config.LogHTTPBody
dumpedBody, _ := httputil.DumpResponse(r.HTTPResponse, logBody)
fmt.Fprintf(out, "%s\n", string(dumpedBody))
} else if r.Error != nil {
fmt.Fprintf(out, "%s\n", r.Error)
}
fmt.Fprintf(out, "-----------------------------------------------------\n")
})
}
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API request.
func (s *Service) MaxRetries() uint {
if s.Config.MaxRetries < 0 {
return s.DefaultMaxRetries
}
return uint(s.Config.MaxRetries)
}
// retryRules returns the delay duration before retrying this request again
func retryRules(r *Request) time.Duration {
delay := time.Duration(math.Pow(2, float64(r.RetryCount))) * 30
return delay * time.Millisecond
}
// retryableCodes is a collection of service response codes which are retry-able
// without any further action.
var retryableCodes = map[string]struct{}{
"RequestError": {},
"ProvisionedThroughputExceededException": {},
"Throttling": {},
}
// credsExpiredCodes is a collection of error codes which signify the credentials
// need to be refreshed. Expired tokens require refreshing of credentials, and
// resigning before the request can be retried.
var credsExpiredCodes = map[string]struct{}{
"ExpiredToken": {},
"ExpiredTokenException": {},
"RequestExpired": {}, // EC2 Only
}
func isCodeRetryable(code string) bool {
if _, ok := retryableCodes[code]; ok {
return true
}
return isCodeExpiredCreds(code)
}
func isCodeExpiredCreds(code string) bool {
_, ok := credsExpiredCodes[code]
return ok
}
// shouldRetry returns if the request should be retried.
func shouldRetry(r *Request) bool {
if r.HTTPResponse.StatusCode >= 500 {
return true
}
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
return isCodeRetryable(err.Code())
}
}
return false
}

View file

@ -0,0 +1,51 @@
package service
import (
"math"
"math/rand"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
)
// DefaultRetryer implements basic retry logic using exponential backoff for
// most services. If you want to implement custom retry logic, implement the
// request.Retryer interface or create a structure type that composes this
// struct and override the specific methods. For example, to override only
// the MaxRetries method:
//
// type retryer struct {
// service.DefaultRetryer
// }
//
// // This implementation always has 100 max retries
// func (d retryer) MaxRetries() uint { return 100 }
type DefaultRetryer struct {
*Service
}
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API request.
func (d DefaultRetryer) MaxRetries() uint {
if aws.IntValue(d.Service.Config.MaxRetries) < 0 {
return d.DefaultMaxRetries
}
return uint(aws.IntValue(d.Service.Config.MaxRetries))
}
var seededRand = rand.New(rand.NewSource(time.Now().UnixNano()))
// RetryRules returns the delay duration before retrying this request again
func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
delay := int(math.Pow(2, float64(r.RetryCount))) * (seededRand.Intn(30) + 30)
return time.Duration(delay) * time.Millisecond
}
// ShouldRetry returns if the request should be retried.
func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
if r.HTTPResponse.StatusCode >= 500 {
return true
}
return r.IsErrorRetryable()
}

133
vendor/github.com/aws/aws-sdk-go/aws/service/service.go generated vendored Normal file
View file

@ -0,0 +1,133 @@
package service
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httputil"
"regexp"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/service/serviceinfo"
"github.com/aws/aws-sdk-go/internal/endpoints"
)
// A Service implements the base service request and response handling
// used by all services.
type Service struct {
serviceinfo.ServiceInfo
request.Retryer
DefaultMaxRetries uint
Handlers request.Handlers
}
var schemeRE = regexp.MustCompile("^([^:]+)://")
// New will return a pointer to a new Server object initialized.
func New(config *aws.Config) *Service {
svc := &Service{ServiceInfo: serviceinfo.ServiceInfo{Config: config}}
svc.Initialize()
return svc
}
// Initialize initializes the service.
func (s *Service) Initialize() {
if s.Config == nil {
s.Config = &aws.Config{}
}
if s.Config.HTTPClient == nil {
s.Config.HTTPClient = http.DefaultClient
}
if s.Config.SleepDelay == nil {
s.Config.SleepDelay = time.Sleep
}
s.Retryer = DefaultRetryer{s}
s.DefaultMaxRetries = 3
s.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
s.Handlers.Build.PushBackNamed(corehandlers.UserAgentHandler)
s.Handlers.Sign.PushBackNamed(corehandlers.BuildContentLengthHandler)
s.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
s.Handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)
s.Handlers.ValidateResponse.PushBackNamed(corehandlers.ValidateResponseHandler)
if !aws.BoolValue(s.Config.DisableParamValidation) {
s.Handlers.Validate.PushBackNamed(corehandlers.ValidateParametersHandler)
}
s.AddDebugHandlers()
s.buildEndpoint()
}
// NewRequest returns a new Request pointer for the service API
// operation and parameters.
func (s *Service) NewRequest(operation *request.Operation, params interface{}, data interface{}) *request.Request {
return request.New(s.ServiceInfo, s.Handlers, s.Retryer, operation, params, data)
}
// buildEndpoint builds the endpoint values the service will use to make requests with.
func (s *Service) buildEndpoint() {
if aws.StringValue(s.Config.Endpoint) != "" {
s.Endpoint = *s.Config.Endpoint
} else if s.Endpoint == "" {
s.Endpoint, s.SigningRegion =
endpoints.EndpointForRegion(s.ServiceName, aws.StringValue(s.Config.Region))
}
if s.Endpoint != "" && !schemeRE.MatchString(s.Endpoint) {
scheme := "https"
if aws.BoolValue(s.Config.DisableSSL) {
scheme = "http"
}
s.Endpoint = scheme + "://" + s.Endpoint
}
}
// AddDebugHandlers injects debug logging handlers into the service to log request
// debug information.
func (s *Service) AddDebugHandlers() {
if !s.Config.LogLevel.AtLeast(aws.LogDebug) {
return
}
s.Handlers.Send.PushFront(logRequest)
s.Handlers.Send.PushBack(logResponse)
}
const logReqMsg = `DEBUG: Request %s/%s Details:
---[ REQUEST POST-SIGN ]-----------------------------
%s
-----------------------------------------------------`
func logRequest(r *request.Request) {
logBody := r.Service.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpRequestOut(r.HTTPRequest, logBody)
if logBody {
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
// Body as a NoOpCloser and will not be reset after read by the HTTP
// client reader.
r.Body.Seek(r.BodyStart, 0)
r.HTTPRequest.Body = ioutil.NopCloser(r.Body)
}
r.Service.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.Service.ServiceName, r.Operation.Name, string(dumpedBody)))
}
const logRespMsg = `DEBUG: Response %s/%s Details:
---[ RESPONSE ]--------------------------------------
%s
-----------------------------------------------------`
func logResponse(r *request.Request) {
var msg = "no reponse data"
if r.HTTPResponse != nil {
logBody := r.Service.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpResponse(r.HTTPResponse, logBody)
msg = string(dumpedBody)
} else if r.Error != nil {
msg = r.Error.Error()
}
r.Service.Config.Logger.Log(fmt.Sprintf(logRespMsg, r.Service.ServiceName, r.Operation.Name, msg))
}

View file

@ -0,0 +1,15 @@
package serviceinfo
import "github.com/aws/aws-sdk-go/aws"
// ServiceInfo wraps immutable data from the service.Service structure.
type ServiceInfo struct {
Config *aws.Config
ServiceName string
APIVersion string
Endpoint string
SigningName string
SigningRegion string
JSONVersion string
TargetPrefix string
}

View file

@ -1,36 +1,10 @@
package aws
import (
"fmt"
"io"
"time"
"sync"
)
// String converts a Go string into a string pointer.
func String(v string) *string {
return &v
}
// Boolean converts a Go bool into a boolean pointer.
func Boolean(v bool) *bool {
return &v
}
// Long converts a Go int64 into a long pointer.
func Long(v int64) *int64 {
return &v
}
// Double converts a Go float64 into a double pointer.
func Double(v float64) *float64 {
return &v
}
// Time converts a Go Time into a Time pointer
func Time(t time.Time) *time.Time {
return &t
}
// ReadSeekCloser wraps a io.Reader returning a ReaderSeakerCloser
func ReadSeekCloser(r io.Reader) ReaderSeekerCloser {
return ReaderSeekerCloser{r}
@ -81,51 +55,34 @@ func (r ReaderSeekerCloser) Close() error {
return nil
}
// A SettableBool provides a boolean value which includes the state if
// the value was set or unset. The set state is in addition to the value's
// value(true|false)
type SettableBool struct {
value bool
set bool
// A WriteAtBuffer provides a in memory buffer supporting the io.WriterAt interface
// Can be used with the s3manager.Downloader to download content to a buffer
// in memory. Safe to use concurrently.
type WriteAtBuffer struct {
buf []byte
m sync.Mutex
}
// SetBool returns a SettableBool with a value set
func SetBool(value bool) SettableBool {
return SettableBool{value: value, set: true}
}
// WriteAt writes a slice of bytes to a buffer starting at the position provided
// The number of bytes written will be returned, or error. Can overwrite previous
// written slices if the write ats overlap.
func (b *WriteAtBuffer) WriteAt(p []byte, pos int64) (n int, err error) {
b.m.Lock()
defer b.m.Unlock()
// Get returns the value. Will always be false if the SettableBool was not set.
func (b *SettableBool) Get() bool {
if !b.set {
return false
expLen := pos + int64(len(p))
if int64(len(b.buf)) < expLen {
newBuf := make([]byte, expLen)
copy(newBuf, b.buf)
b.buf = newBuf
}
return b.value
copy(b.buf[pos:], p)
return len(p), nil
}
// Set sets the value and updates the state that the value has been set.
func (b *SettableBool) Set(value bool) {
b.value = value
b.set = true
}
// IsSet returns if the value has been set
func (b *SettableBool) IsSet() bool {
return b.set
}
// Reset resets the state and value of the SettableBool to its initial default
// state of not set and zero value.
func (b *SettableBool) Reset() {
b.value = false
b.set = false
}
// String returns the string representation of the value if set. Zero if not set.
func (b *SettableBool) String() string {
return fmt.Sprintf("%t", b.Get())
}
// GoString returns the string representation of the SettableBool value and state
func (b *SettableBool) GoString() string {
return fmt.Sprintf("Bool{value:%t, set:%t}", b.value, b.set)
// Bytes returns a slice of bytes written to the buffer.
func (b *WriteAtBuffer) Bytes() []byte {
b.m.Lock()
defer b.m.Unlock()
return b.buf[:len(b.buf):len(b.buf)]
}

View file

@ -5,4 +5,4 @@ package aws
const SDKName = "aws-sdk-go"
// SDKVersion is the version of this SDK
const SDKVersion = "0.6.4"
const SDKVersion = "0.9.10"

View file

@ -1,28 +0,0 @@
package endpoints
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGlobalEndpoints(t *testing.T) {
region := "mock-region-1"
svcs := []string{"cloudfront", "iam", "importexport", "route53", "sts"}
for _, name := range svcs {
ep, sr := EndpointForRegion(name, region)
assert.Equal(t, name+".amazonaws.com", ep)
assert.Equal(t, "us-east-1", sr)
}
}
func TestServicesInCN(t *testing.T) {
region := "cn-north-1"
svcs := []string{"cloudfront", "iam", "importexport", "route53", "sts", "s3"}
for _, name := range svcs {
ep, _ := EndpointForRegion(name, region)
assert.Equal(t, name+"."+region+".amazonaws.com.cn", ep)
}
}

View file

@ -6,13 +6,13 @@ package query
import (
"net/url"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/protocol/query/queryutil"
)
// Build builds a request for an AWS Query service.
func Build(r *aws.Request) {
func Build(r *request.Request) {
body := url.Values{
"Action": {r.Operation.Name},
"Version": {r.Service.APIVersion},

File diff suppressed because it is too large Load diff

View file

@ -5,13 +5,13 @@ package query
import (
"encoding/xml"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil"
)
// Unmarshal unmarshals a response for an AWS Query service.
func Unmarshal(r *aws.Request) {
func Unmarshal(r *request.Request) {
defer r.HTTPResponse.Body.Close()
if r.DataFilled() {
decoder := xml.NewDecoder(r.HTTPResponse.Body)
@ -24,6 +24,6 @@ func Unmarshal(r *aws.Request) {
}
// UnmarshalMeta unmarshals header response values for an AWS Query service.
func UnmarshalMeta(r *aws.Request) {
func UnmarshalMeta(r *request.Request) {
// TODO implement unmarshaling of request IDs
}

View file

@ -4,8 +4,8 @@ import (
"encoding/xml"
"io"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
type xmlErrorResponse struct {
@ -16,7 +16,7 @@ type xmlErrorResponse struct {
}
// UnmarshalError unmarshals an error response for an AWS Query service.
func UnmarshalError(r *aws.Request) {
func UnmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
resp := &xmlErrorResponse{}

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,4 @@
// Package rest provides RESTful serialisation of AWS requests and responses.
// Package rest provides RESTful serialization of AWS requests and responses.
package rest
import (
@ -13,8 +13,8 @@ import (
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// RFC822 returns an RFC822 formatted timestamp for AWS protocols
@ -37,7 +37,7 @@ func init() {
}
// Build builds the REST component of a service request.
func Build(r *aws.Request) {
func Build(r *request.Request) {
if r.ParamsFilled() {
v := reflect.ValueOf(r.Params).Elem()
buildLocationElements(r, v)
@ -45,7 +45,7 @@ func Build(r *aws.Request) {
}
}
func buildLocationElements(r *aws.Request, v reflect.Value) {
func buildLocationElements(r *request.Request, v reflect.Value) {
query := r.HTTPRequest.URL.Query()
for i := 0; i < v.NumField(); i++ {
@ -87,7 +87,7 @@ func buildLocationElements(r *aws.Request, v reflect.Value) {
updatePath(r.HTTPRequest.URL, r.HTTPRequest.URL.Path)
}
func buildBody(r *aws.Request, v reflect.Value) {
func buildBody(r *request.Request, v reflect.Value) {
if field, ok := v.Type().FieldByName("SDKShapeTraits"); ok {
if payloadName := field.Tag.Get("payload"); payloadName != "" {
pfield, _ := v.Type().FieldByName(payloadName)
@ -112,7 +112,7 @@ func buildBody(r *aws.Request, v reflect.Value) {
}
}
func buildHeader(r *aws.Request, v reflect.Value, name string) {
func buildHeader(r *request.Request, v reflect.Value, name string) {
str, err := convertType(v)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to encode REST request", err)
@ -121,7 +121,7 @@ func buildHeader(r *aws.Request, v reflect.Value, name string) {
}
}
func buildHeaderMap(r *aws.Request, v reflect.Value, prefix string) {
func buildHeaderMap(r *request.Request, v reflect.Value, prefix string) {
for _, key := range v.MapKeys() {
str, err := convertType(v.MapIndex(key))
if err != nil {
@ -132,7 +132,7 @@ func buildHeaderMap(r *aws.Request, v reflect.Value, prefix string) {
}
}
func buildURI(r *aws.Request, v reflect.Value, name string) {
func buildURI(r *request.Request, v reflect.Value, name string) {
value, err := convertType(v)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to encode REST request", err)
@ -144,7 +144,7 @@ func buildURI(r *aws.Request, v reflect.Value, name string) {
}
}
func buildQueryString(r *aws.Request, v reflect.Value, name string, query url.Values) {
func buildQueryString(r *request.Request, v reflect.Value, name string, query url.Values) {
str, err := convertType(v)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to encode REST request", err)
@ -156,8 +156,13 @@ func buildQueryString(r *aws.Request, v reflect.Value, name string, query url.Va
func updatePath(url *url.URL, urlPath string) {
scheme, query := url.Scheme, url.RawQuery
hasSlash := strings.HasSuffix(urlPath, "/")
// clean up path
urlPath = path.Clean(urlPath)
if hasSlash && !strings.HasSuffix(urlPath, "/") {
urlPath += "/"
}
// get formatted URL minus scheme so we can build this into Opaque
url.Scheme, url.Path, url.RawQuery = "", "", ""

View file

@ -12,18 +12,27 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// Unmarshal unmarshals the REST component of a response in a REST service.
func Unmarshal(r *aws.Request) {
func Unmarshal(r *request.Request) {
if r.DataFilled() {
v := reflect.Indirect(reflect.ValueOf(r.Data))
unmarshalBody(r, v)
}
}
// UnmarshalMeta unmarshals the REST metadata of a response in a REST service
func UnmarshalMeta(r *request.Request) {
r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid")
if r.DataFilled() {
v := reflect.Indirect(reflect.ValueOf(r.Data))
unmarshalLocationElements(r, v)
}
}
func unmarshalBody(r *aws.Request, v reflect.Value) {
func unmarshalBody(r *request.Request, v reflect.Value) {
if field, ok := v.Type().FieldByName("SDKShapeTraits"); ok {
if payloadName := field.Tag.Get("payload"); payloadName != "" {
pfield, _ := v.Type().FieldByName(payloadName)
@ -64,7 +73,7 @@ func unmarshalBody(r *aws.Request, v reflect.Value) {
}
}
func unmarshalLocationElements(r *aws.Request, v reflect.Value) {
func unmarshalLocationElements(r *request.Request, v reflect.Value) {
for i := 0; i < v.NumField(); i++ {
m, field := v.Field(i), v.Type().Field(i)
if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) {

File diff suppressed because it is too large Load diff

View file

@ -9,22 +9,22 @@ import (
"bytes"
"encoding/xml"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/protocol/query"
"github.com/aws/aws-sdk-go/internal/protocol/rest"
"github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil"
)
// Build builds a request payload for the REST XML protocol.
func Build(r *aws.Request) {
func Build(r *request.Request) {
rest.Build(r)
if t := rest.PayloadType(r.Params); t == "structure" || t == "" {
var buf bytes.Buffer
err := xmlutil.BuildXML(r.Params, xml.NewEncoder(&buf))
if err != nil {
r.Error = awserr.New("SerializationError", "failed to enode rest XML request", err)
r.Error = awserr.New("SerializationError", "failed to encode rest XML request", err)
return
}
r.SetBufferBody(buf.Bytes())
@ -32,7 +32,7 @@ func Build(r *aws.Request) {
}
// Unmarshal unmarshals a payload response for the REST XML protocol.
func Unmarshal(r *aws.Request) {
func Unmarshal(r *request.Request) {
if t := rest.PayloadType(r.Data); t == "structure" || t == "" {
defer r.HTTPResponse.Body.Close()
decoder := xml.NewDecoder(r.HTTPResponse.Body)
@ -41,15 +41,17 @@ func Unmarshal(r *aws.Request) {
r.Error = awserr.New("SerializationError", "failed to decode REST XML response", err)
return
}
} else {
rest.Unmarshal(r)
}
}
// UnmarshalMeta unmarshals response headers for the REST XML protocol.
func UnmarshalMeta(r *aws.Request) {
rest.Unmarshal(r)
func UnmarshalMeta(r *request.Request) {
rest.UnmarshalMeta(r)
}
// UnmarshalError unmarshals a response error for the REST XML protocol.
func UnmarshalError(r *aws.Request) {
func UnmarshalError(r *request.Request) {
query.UnmarshalError(r)
}

File diff suppressed because it is too large Load diff

View file

@ -1,43 +0,0 @@
package v4_test
import (
"net/url"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
func TestPresignHandler(t *testing.T) {
svc := s3.New(nil)
req, _ := svc.PutObjectRequest(&s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
ContentDisposition: aws.String("a+b c$d"),
ACL: aws.String("public-read"),
})
req.Time = time.Unix(0, 0)
urlstr, err := req.Presign(5 * time.Minute)
assert.NoError(t, err)
expectedDate := "19700101T000000Z"
expectedHeaders := "host;x-amz-acl"
expectedSig := "7edcb4e3a1bf12f4989018d75acbe3a7f03df24bd6f3112602d59fc551f0e4e2"
expectedCred := "AKID/19700101/mock-region/s3/aws4_request"
u, _ := url.Parse(urlstr)
urlQ := u.Query()
assert.Equal(t, expectedSig, urlQ.Get("X-Amz-Signature"))
assert.Equal(t, expectedCred, urlQ.Get("X-Amz-Credential"))
assert.Equal(t, expectedHeaders, urlQ.Get("X-Amz-SignedHeaders"))
assert.Equal(t, expectedDate, urlQ.Get("X-Amz-Date"))
assert.Equal(t, "300", urlQ.Get("X-Amz-Expires"))
assert.NotContains(t, urlstr, "+") // + encoded as %20
}

View file

@ -14,10 +14,10 @@ import (
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/internal/protocol/rest"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/protocol/rest"
)
const (
@ -43,8 +43,8 @@ type signer struct {
Credentials *credentials.Credentials
Query url.Values
Body io.ReadSeeker
Debug uint
Logger io.Writer
Debug aws.LogLevelType
Logger aws.Logger
isPresign bool
formattedTime string
@ -64,7 +64,7 @@ type signer struct {
// Will sign the requests with the service config's Credentials object
// Signing is skipped if the credentials is the credentials.AnonymousCredentials
// object.
func Sign(req *aws.Request) {
func Sign(req *request.Request) {
// If the request does not need to be signed ignore the signing of the
// request if the AnonymousCredentials object is used.
if req.Service.Config.Credentials == credentials.AnonymousCredentials {
@ -73,7 +73,7 @@ func Sign(req *aws.Request) {
region := req.Service.SigningRegion
if region == "" {
region = req.Service.Config.Region
region = aws.StringValue(req.Service.Config.Region)
}
name := req.Service.SigningName
@ -90,7 +90,7 @@ func Sign(req *aws.Request) {
ServiceName: name,
Region: region,
Credentials: req.Service.Config.Credentials,
Debug: req.Service.Config.LogLevel,
Debug: req.Service.Config.LogLevel.Value(),
Logger: req.Service.Config.Logger,
}
@ -138,28 +138,33 @@ func (v4 *signer) sign() error {
v4.build()
if v4.Debug > 0 {
if v4.Debug.Matches(aws.LogDebugWithSigning) {
v4.logSigningInfo()
}
return nil
}
const logSignInfoMsg = `DEBUG: Request Signiture:
---[ CANONICAL STRING ]-----------------------------
%s
---[ STRING TO SIGN ]--------------------------------
%s%s
-----------------------------------------------------`
const logSignedURLMsg = `
---[ SIGNED URL ]------------------------------------
%s`
func (v4 *signer) logSigningInfo() {
out := v4.Logger
fmt.Fprintf(out, "---[ CANONICAL STRING ]-----------------------------\n")
fmt.Fprintln(out, v4.canonicalString)
fmt.Fprintf(out, "---[ STRING TO SIGN ]--------------------------------\n")
fmt.Fprintln(out, v4.stringToSign)
signedURLMsg := ""
if v4.isPresign {
fmt.Fprintf(out, "---[ SIGNED URL ]--------------------------------\n")
fmt.Fprintln(out, v4.Request.URL)
signedURLMsg = fmt.Sprintf(logSignedURLMsg, v4.Request.URL.String())
}
fmt.Fprintf(out, "-----------------------------------------------------\n")
msg := fmt.Sprintf(logSignInfoMsg, v4.canonicalString, v4.stringToSign, signedURLMsg)
v4.Logger.Log(msg)
}
func (v4 *signer) build() {
v4.buildTime() // no depends
v4.buildCredentialString() // no depends
if v4.isPresign {

View file

@ -1,245 +0,0 @@
package v4
import (
"net/http"
"strings"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
)
func buildSigner(serviceName string, region string, signTime time.Time, expireTime time.Duration, body string) signer {
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
reader := strings.NewReader(body)
req, _ := http.NewRequest("POST", endpoint, reader)
req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
req.Header.Add("X-Amz-Target", "prefix.Operation")
req.Header.Add("Content-Type", "application/x-amz-json-1.0")
req.Header.Add("Content-Length", string(len(body)))
req.Header.Add("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
return signer{
Request: req,
Time: signTime,
ExpireTime: expireTime,
Query: req.URL.Query(),
Body: reader,
ServiceName: serviceName,
Region: region,
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
}
}
func removeWS(text string) string {
text = strings.Replace(text, " ", "", -1)
text = strings.Replace(text, "\n", "", -1)
text = strings.Replace(text, "\t", "", -1)
return text
}
func assertEqual(t *testing.T, expected, given string) {
if removeWS(expected) != removeWS(given) {
t.Errorf("\nExpected: %s\nGiven: %s", expected, given)
}
}
func TestPresignRequest(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Unix(0, 0), 300*time.Second, "{}")
signer.sign()
expectedDate := "19700101T000000Z"
expectedHeaders := "host;x-amz-meta-other-header;x-amz-target"
expectedSig := "5eeedebf6f995145ce56daa02902d10485246d3defb34f97b973c1f40ab82d36"
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
q := signer.Request.URL.Query()
assert.Equal(t, expectedSig, q.Get("X-Amz-Signature"))
assert.Equal(t, expectedCred, q.Get("X-Amz-Credential"))
assert.Equal(t, expectedHeaders, q.Get("X-Amz-SignedHeaders"))
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
}
func TestSignRequest(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Unix(0, 0), 0, "{}")
signer.sign()
expectedDate := "19700101T000000Z"
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=host;x-amz-date;x-amz-meta-other-header;x-amz-security-token;x-amz-target, Signature=69ada33fec48180dab153576e4dd80c4e04124f80dda3eccfed8a67c2b91ed5e"
q := signer.Request.Header
assert.Equal(t, expectedSig, q.Get("Authorization"))
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
}
func TestSignEmptyBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "")
signer.Body = nil
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", hash)
}
func TestSignBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "hello")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash)
}
func TestSignSeekedBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, " hello")
signer.Body.Read(make([]byte, 3)) // consume first 3 bytes so body is now "hello"
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash)
start, _ := signer.Body.Seek(0, 1)
assert.Equal(t, int64(3), start)
}
func TestPresignEmptyBodyS3(t *testing.T) {
signer := buildSigner("s3", "us-east-1", time.Now(), 5*time.Minute, "hello")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "UNSIGNED-PAYLOAD", hash)
}
func TestSignPrecomputedBodyChecksum(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "hello")
signer.Request.Header.Set("X-Amz-Content-Sha256", "PRECOMPUTED")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "PRECOMPUTED", hash)
}
func TestAnonymousCredentials(t *testing.T) {
r := aws.NewRequest(
aws.NewService(&aws.Config{Credentials: credentials.AnonymousCredentials}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
urlQ := r.HTTPRequest.URL.Query()
assert.Empty(t, urlQ.Get("X-Amz-Signature"))
assert.Empty(t, urlQ.Get("X-Amz-Credential"))
assert.Empty(t, urlQ.Get("X-Amz-SignedHeaders"))
assert.Empty(t, urlQ.Get("X-Amz-Date"))
hQ := r.HTTPRequest.Header
assert.Empty(t, hQ.Get("Authorization"))
assert.Empty(t, hQ.Get("X-Amz-Date"))
}
func TestIgnoreResignRequestWithValidCreds(t *testing.T) {
r := aws.NewRequest(
aws.NewService(&aws.Config{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
Region: "us-west-2",
}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
sig := r.HTTPRequest.Header.Get("Authorization")
Sign(r)
assert.Equal(t, sig, r.HTTPRequest.Header.Get("Authorization"))
}
func TestIgnorePreResignRequestWithValidCreds(t *testing.T) {
r := aws.NewRequest(
aws.NewService(&aws.Config{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
Region: "us-west-2",
}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
r.ExpireTime = time.Minute * 10
Sign(r)
sig := r.HTTPRequest.Header.Get("X-Amz-Signature")
Sign(r)
assert.Equal(t, sig, r.HTTPRequest.Header.Get("X-Amz-Signature"))
}
func TestResignRequestExpiredCreds(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
r := aws.NewRequest(
aws.NewService(&aws.Config{Credentials: creds}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
querySig := r.HTTPRequest.Header.Get("Authorization")
creds.Expire()
Sign(r)
assert.NotEqual(t, querySig, r.HTTPRequest.Header.Get("Authorization"))
}
func TestPreResignRequestExpiredCreds(t *testing.T) {
provider := &credentials.StaticProvider{credentials.Value{"AKID", "SECRET", "SESSION"}}
creds := credentials.NewCredentials(provider)
r := aws.NewRequest(
aws.NewService(&aws.Config{Credentials: creds}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
r.ExpireTime = time.Minute * 10
Sign(r)
querySig := r.HTTPRequest.URL.Query().Get("X-Amz-Signature")
creds.Expire()
r.Time = time.Now().Add(time.Hour * 48)
Sign(r)
assert.NotEqual(t, querySig, r.HTTPRequest.URL.Query().Get("X-Amz-Signature"))
}
func BenchmarkPresignRequest(b *testing.B) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 300*time.Second, "{}")
for i := 0; i < b.N; i++ {
signer.sign()
}
}
func BenchmarkSignRequest(b *testing.B) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "{}")
for i := 0; i < b.N; i++ {
signer.sign()
}
}

File diff suppressed because it is too large Load diff

View file

@ -7,11 +7,12 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
)
var reBucketLocation = regexp.MustCompile(`>([^<>]+)<\/Location`)
func buildGetBucketLocation(r *aws.Request) {
func buildGetBucketLocation(r *request.Request) {
if r.DataFilled() {
out := r.Data.(*GetBucketLocationOutput)
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
@ -28,14 +29,14 @@ func buildGetBucketLocation(r *aws.Request) {
}
}
func populateLocationConstraint(r *aws.Request) {
if r.ParamsFilled() && r.Config.Region != "us-east-1" {
func populateLocationConstraint(r *request.Request) {
if r.ParamsFilled() && aws.StringValue(r.Service.Config.Region) != "us-east-1" {
in := r.Params.(*CreateBucketInput)
if in.CreateBucketConfiguration == nil {
r.Params = awsutil.CopyOf(r.Params)
in = r.Params.(*CreateBucketInput)
in.CreateBucketConfiguration = &CreateBucketConfiguration{
LocationConstraint: &r.Config.Region,
LocationConstraint: r.Service.Config.Region,
}
}
}

View file

@ -1,75 +0,0 @@
package s3_test
import (
"bytes"
"io/ioutil"
"net/http"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
var s3LocationTests = []struct {
body string
loc string
}{
{`<?xml version="1.0" encoding="UTF-8"?><LocationConstraint xmlns="http://s3.amazonaws.com/doc/2006-03-01/"/>`, ``},
{`<?xml version="1.0" encoding="UTF-8"?><LocationConstraint xmlns="http://s3.amazonaws.com/doc/2006-03-01/">EU</LocationConstraint>`, `EU`},
}
func TestGetBucketLocation(t *testing.T) {
for _, test := range s3LocationTests {
s := s3.New(nil)
s.Handlers.Send.Clear()
s.Handlers.Send.PushBack(func(r *aws.Request) {
reader := ioutil.NopCloser(bytes.NewReader([]byte(test.body)))
r.HTTPResponse = &http.Response{StatusCode: 200, Body: reader}
})
resp, err := s.GetBucketLocation(&s3.GetBucketLocationInput{Bucket: aws.String("bucket")})
assert.NoError(t, err)
if test.loc == "" {
assert.Nil(t, resp.LocationConstraint)
} else {
assert.Equal(t, test.loc, *resp.LocationConstraint)
}
}
}
func TestPopulateLocationConstraint(t *testing.T) {
s := s3.New(nil)
in := &s3.CreateBucketInput{
Bucket: aws.String("bucket"),
}
req, _ := s.CreateBucketRequest(in)
err := req.Build()
assert.NoError(t, err)
assert.Equal(t, "mock-region", awsutil.ValuesAtPath(req.Params, "CreateBucketConfiguration.LocationConstraint")[0])
assert.Nil(t, in.CreateBucketConfiguration) // don't modify original params
}
func TestNoPopulateLocationConstraintIfProvided(t *testing.T) {
s := s3.New(nil)
req, _ := s.CreateBucketRequest(&s3.CreateBucketInput{
Bucket: aws.String("bucket"),
CreateBucketConfiguration: &s3.CreateBucketConfiguration{},
})
err := req.Build()
assert.NoError(t, err)
assert.Equal(t, 0, len(awsutil.ValuesAtPath(req.Params, "CreateBucketConfiguration.LocationConstraint")))
}
func TestNoPopulateLocationConstraintIfClassic(t *testing.T) {
s := s3.New(&aws.Config{Region: "us-east-1"})
req, _ := s.CreateBucketRequest(&s3.CreateBucketInput{
Bucket: aws.String("bucket"),
})
err := req.Build()
assert.NoError(t, err)
assert.Equal(t, 0, len(awsutil.ValuesAtPath(req.Params, "CreateBucketConfiguration.LocationConstraint")))
}

View file

@ -5,13 +5,13 @@ import (
"encoding/base64"
"io"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// contentMD5 computes and sets the HTTP Content-MD5 header for requests that
// require it.
func contentMD5(r *aws.Request) {
func contentMD5(r *request.Request) {
h := md5.New()
// hash the body. seek back to the first position after reading to reset

View file

@ -1,9 +1,12 @@
package s3
import "github.com/aws/aws-sdk-go/aws"
import (
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/service"
)
func init() {
initService = func(s *aws.Service) {
initService = func(s *service.Service) {
// Support building custom host-style bucket endpoints
s.Handlers.Build.PushFront(updateHostWithBucket)
@ -16,9 +19,9 @@ func init() {
s.Handlers.UnmarshalError.PushBack(unmarshalError)
}
initRequest = func(r *aws.Request) {
initRequest = func(r *request.Request) {
switch r.Operation.Name {
case opPutBucketCORS, opPutBucketLifecycle, opPutBucketPolicy, opPutBucketTagging, opDeleteObjects:
case opPutBucketCors, opPutBucketLifecycle, opPutBucketPolicy, opPutBucketTagging, opDeleteObjects:
// These S3 operations require Content-MD5 to be set
r.Handlers.Build.PushBack(contentMD5)
case opGetBucketLocation:
@ -27,6 +30,8 @@ func init() {
case opCreateBucket:
// Auto-populate LocationConstraint with current region
r.Handlers.Validate.PushFront(populateLocationConstraint)
case opCopyObject, opUploadPartCopy, opCompleteMultipartUpload:
r.Handlers.Unmarshal.PushFront(copyMultipartStatusOKUnmarhsalError)
}
}
}

View file

@ -1,90 +0,0 @@
package s3_test
import (
"crypto/md5"
"encoding/base64"
"io/ioutil"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
func assertMD5(t *testing.T, req *aws.Request) {
err := req.Build()
assert.NoError(t, err)
b, _ := ioutil.ReadAll(req.HTTPRequest.Body)
out := md5.Sum(b)
assert.NotEmpty(t, b)
assert.Equal(t, base64.StdEncoding.EncodeToString(out[:]), req.HTTPRequest.Header.Get("Content-MD5"))
}
func TestMD5InPutBucketCORS(t *testing.T) {
svc := s3.New(nil)
req, _ := svc.PutBucketCORSRequest(&s3.PutBucketCORSInput{
Bucket: aws.String("bucketname"),
CORSConfiguration: &s3.CORSConfiguration{
CORSRules: []*s3.CORSRule{
{AllowedMethods: []*string{aws.String("GET")}},
},
},
})
assertMD5(t, req)
}
func TestMD5InPutBucketLifecycle(t *testing.T) {
svc := s3.New(nil)
req, _ := svc.PutBucketLifecycleRequest(&s3.PutBucketLifecycleInput{
Bucket: aws.String("bucketname"),
LifecycleConfiguration: &s3.LifecycleConfiguration{
Rules: []*s3.LifecycleRule{
{
ID: aws.String("ID"),
Prefix: aws.String("Prefix"),
Status: aws.String("Enabled"),
},
},
},
})
assertMD5(t, req)
}
func TestMD5InPutBucketPolicy(t *testing.T) {
svc := s3.New(nil)
req, _ := svc.PutBucketPolicyRequest(&s3.PutBucketPolicyInput{
Bucket: aws.String("bucketname"),
Policy: aws.String("{}"),
})
assertMD5(t, req)
}
func TestMD5InPutBucketTagging(t *testing.T) {
svc := s3.New(nil)
req, _ := svc.PutBucketTaggingRequest(&s3.PutBucketTaggingInput{
Bucket: aws.String("bucketname"),
Tagging: &s3.Tagging{
TagSet: []*s3.Tag{
{Key: aws.String("KEY"), Value: aws.String("VALUE")},
},
},
})
assertMD5(t, req)
}
func TestMD5InDeleteObjects(t *testing.T) {
svc := s3.New(nil)
req, _ := svc.DeleteObjectsRequest(&s3.DeleteObjectsInput{
Bucket: aws.String("bucketname"),
Delete: &s3.Delete{
Objects: []*s3.ObjectIdentifier{
{Key: aws.String("key")},
},
},
})
assertMD5(t, req)
}

File diff suppressed because it is too large Load diff

View file

@ -6,6 +6,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
)
var reDomain = regexp.MustCompile(`^[a-z0-9][a-z0-9\.\-]{1,61}[a-z0-9]$`)
@ -22,8 +23,8 @@ func dnsCompatibleBucketName(bucket string) bool {
// hostStyleBucketName returns true if the request should put the bucket in
// the host. This is false if S3ForcePathStyle is explicitly set or if the
// bucket is not DNS compatible.
func hostStyleBucketName(r *aws.Request, bucket string) bool {
if r.Config.S3ForcePathStyle {
func hostStyleBucketName(r *request.Request, bucket string) bool {
if aws.BoolValue(r.Service.Config.S3ForcePathStyle) {
return false
}
@ -33,11 +34,17 @@ func hostStyleBucketName(r *aws.Request, bucket string) bool {
return false
}
// GetBucketLocation should be able to be called from any region within
// a partition, and return the associated region of the bucket.
if r.Operation.Name == opGetBucketLocation {
return false
}
// Use host-style if the bucket is DNS compatible
return dnsCompatibleBucketName(bucket)
}
func updateHostWithBucket(r *aws.Request) {
func updateHostWithBucket(r *request.Request) {
b := awsutil.ValuesAtPath(r.Params, "Bucket")
if len(b) == 0 {
return

View file

@ -1,61 +0,0 @@
package s3_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
type s3BucketTest struct {
bucket string
url string
}
var (
_ = unit.Imported
sslTests = []s3BucketTest{
{"abc", "https://abc.s3.mock-region.amazonaws.com/"},
{"a$b$c", "https://s3.mock-region.amazonaws.com/a%24b%24c"},
{"a.b.c", "https://s3.mock-region.amazonaws.com/a.b.c"},
{"a..bc", "https://s3.mock-region.amazonaws.com/a..bc"},
}
nosslTests = []s3BucketTest{
{"a.b.c", "http://a.b.c.s3.mock-region.amazonaws.com/"},
{"a..bc", "http://s3.mock-region.amazonaws.com/a..bc"},
}
forcepathTests = []s3BucketTest{
{"abc", "https://s3.mock-region.amazonaws.com/abc"},
{"a$b$c", "https://s3.mock-region.amazonaws.com/a%24b%24c"},
{"a.b.c", "https://s3.mock-region.amazonaws.com/a.b.c"},
{"a..bc", "https://s3.mock-region.amazonaws.com/a..bc"},
}
)
func runTests(t *testing.T, svc *s3.S3, tests []s3BucketTest) {
for _, test := range tests {
req, _ := svc.ListObjectsRequest(&s3.ListObjectsInput{Bucket: &test.bucket})
req.Build()
assert.Equal(t, test.url, req.HTTPRequest.URL.String())
}
}
func TestHostStyleBucketBuild(t *testing.T) {
s := s3.New(nil)
runTests(t, s, sslTests)
}
func TestHostStyleBucketBuildNoSSL(t *testing.T) {
s := s3.New(&aws.Config{DisableSSL: true})
runTests(t, s, nosslTests)
}
func TestPathStyleBucketBuild(t *testing.T) {
s := s3.New(&aws.Config{S3ForcePathStyle: true})
runTests(t, s, forcepathTests)
}

View file

@ -1,119 +0,0 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
// Package s3iface provides an interface for the Amazon Simple Storage Service.
package s3iface
import (
"github.com/aws/aws-sdk-go/service/s3"
)
// S3API is the interface type for s3.S3.
type S3API interface {
AbortMultipartUpload(*s3.AbortMultipartUploadInput) (*s3.AbortMultipartUploadOutput, error)
CompleteMultipartUpload(*s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error)
CopyObject(*s3.CopyObjectInput) (*s3.CopyObjectOutput, error)
CreateBucket(*s3.CreateBucketInput) (*s3.CreateBucketOutput, error)
CreateMultipartUpload(*s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error)
DeleteBucket(*s3.DeleteBucketInput) (*s3.DeleteBucketOutput, error)
DeleteBucketCORS(*s3.DeleteBucketCORSInput) (*s3.DeleteBucketCORSOutput, error)
DeleteBucketLifecycle(*s3.DeleteBucketLifecycleInput) (*s3.DeleteBucketLifecycleOutput, error)
DeleteBucketPolicy(*s3.DeleteBucketPolicyInput) (*s3.DeleteBucketPolicyOutput, error)
DeleteBucketReplication(*s3.DeleteBucketReplicationInput) (*s3.DeleteBucketReplicationOutput, error)
DeleteBucketTagging(*s3.DeleteBucketTaggingInput) (*s3.DeleteBucketTaggingOutput, error)
DeleteBucketWebsite(*s3.DeleteBucketWebsiteInput) (*s3.DeleteBucketWebsiteOutput, error)
DeleteObject(*s3.DeleteObjectInput) (*s3.DeleteObjectOutput, error)
DeleteObjects(*s3.DeleteObjectsInput) (*s3.DeleteObjectsOutput, error)
GetBucketACL(*s3.GetBucketACLInput) (*s3.GetBucketACLOutput, error)
GetBucketCORS(*s3.GetBucketCORSInput) (*s3.GetBucketCORSOutput, error)
GetBucketLifecycle(*s3.GetBucketLifecycleInput) (*s3.GetBucketLifecycleOutput, error)
GetBucketLocation(*s3.GetBucketLocationInput) (*s3.GetBucketLocationOutput, error)
GetBucketLogging(*s3.GetBucketLoggingInput) (*s3.GetBucketLoggingOutput, error)
GetBucketNotification(*s3.GetBucketNotificationConfigurationRequest) (*s3.NotificationConfigurationDeprecated, error)
GetBucketNotificationConfiguration(*s3.GetBucketNotificationConfigurationRequest) (*s3.NotificationConfiguration, error)
GetBucketPolicy(*s3.GetBucketPolicyInput) (*s3.GetBucketPolicyOutput, error)
GetBucketReplication(*s3.GetBucketReplicationInput) (*s3.GetBucketReplicationOutput, error)
GetBucketRequestPayment(*s3.GetBucketRequestPaymentInput) (*s3.GetBucketRequestPaymentOutput, error)
GetBucketTagging(*s3.GetBucketTaggingInput) (*s3.GetBucketTaggingOutput, error)
GetBucketVersioning(*s3.GetBucketVersioningInput) (*s3.GetBucketVersioningOutput, error)
GetBucketWebsite(*s3.GetBucketWebsiteInput) (*s3.GetBucketWebsiteOutput, error)
GetObject(*s3.GetObjectInput) (*s3.GetObjectOutput, error)
GetObjectACL(*s3.GetObjectACLInput) (*s3.GetObjectACLOutput, error)
GetObjectTorrent(*s3.GetObjectTorrentInput) (*s3.GetObjectTorrentOutput, error)
HeadBucket(*s3.HeadBucketInput) (*s3.HeadBucketOutput, error)
HeadObject(*s3.HeadObjectInput) (*s3.HeadObjectOutput, error)
ListBuckets(*s3.ListBucketsInput) (*s3.ListBucketsOutput, error)
ListMultipartUploads(*s3.ListMultipartUploadsInput) (*s3.ListMultipartUploadsOutput, error)
ListObjectVersions(*s3.ListObjectVersionsInput) (*s3.ListObjectVersionsOutput, error)
ListObjects(*s3.ListObjectsInput) (*s3.ListObjectsOutput, error)
ListParts(*s3.ListPartsInput) (*s3.ListPartsOutput, error)
PutBucketACL(*s3.PutBucketACLInput) (*s3.PutBucketACLOutput, error)
PutBucketCORS(*s3.PutBucketCORSInput) (*s3.PutBucketCORSOutput, error)
PutBucketLifecycle(*s3.PutBucketLifecycleInput) (*s3.PutBucketLifecycleOutput, error)
PutBucketLogging(*s3.PutBucketLoggingInput) (*s3.PutBucketLoggingOutput, error)
PutBucketNotification(*s3.PutBucketNotificationInput) (*s3.PutBucketNotificationOutput, error)
PutBucketNotificationConfiguration(*s3.PutBucketNotificationConfigurationInput) (*s3.PutBucketNotificationConfigurationOutput, error)
PutBucketPolicy(*s3.PutBucketPolicyInput) (*s3.PutBucketPolicyOutput, error)
PutBucketReplication(*s3.PutBucketReplicationInput) (*s3.PutBucketReplicationOutput, error)
PutBucketRequestPayment(*s3.PutBucketRequestPaymentInput) (*s3.PutBucketRequestPaymentOutput, error)
PutBucketTagging(*s3.PutBucketTaggingInput) (*s3.PutBucketTaggingOutput, error)
PutBucketVersioning(*s3.PutBucketVersioningInput) (*s3.PutBucketVersioningOutput, error)
PutBucketWebsite(*s3.PutBucketWebsiteInput) (*s3.PutBucketWebsiteOutput, error)
PutObject(*s3.PutObjectInput) (*s3.PutObjectOutput, error)
PutObjectACL(*s3.PutObjectACLInput) (*s3.PutObjectACLOutput, error)
RestoreObject(*s3.RestoreObjectInput) (*s3.RestoreObjectOutput, error)
UploadPart(*s3.UploadPartInput) (*s3.UploadPartOutput, error)
UploadPartCopy(*s3.UploadPartCopyInput) (*s3.UploadPartCopyOutput, error)
}

View file

@ -1,15 +0,0 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
package s3iface_test
import (
"testing"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/stretchr/testify/assert"
)
func TestInterface(t *testing.T) {
assert.Implements(t, (*s3iface.S3API)(nil), s3.New(nil))
}

View file

@ -1,257 +0,0 @@
package s3manager
import (
"fmt"
"io"
"strconv"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/service/s3"
)
// The default range of bytes to get at a time when using Download().
var DefaultDownloadPartSize int64 = 1024 * 1024 * 5
// The default number of goroutines to spin up when using Download().
var DefaultDownloadConcurrency = 5
// The default set of options used when opts is nil in Download().
var DefaultDownloadOptions = &DownloadOptions{
PartSize: DefaultDownloadPartSize,
Concurrency: DefaultDownloadConcurrency,
}
// DownloadOptions keeps tracks of extra options to pass to an Download() call.
type DownloadOptions struct {
// The buffer size (in bytes) to use when buffering data into chunks and
// sending them as parts to S3. The minimum allowed part size is 5MB, and
// if this value is set to zero, the DefaultPartSize value will be used.
PartSize int64
// The number of goroutines to spin up in parallel when sending parts.
// If this is set to zero, the DefaultConcurrency value will be used.
Concurrency int
// An S3 client to use when performing downloads. Leave this as nil to use
// a default client.
S3 *s3.S3
}
// NewDownloader creates a new Downloader structure that downloads an object
// from S3 in concurrent chunks. Pass in an optional DownloadOptions struct
// to customize the downloader behavior.
func NewDownloader(opts *DownloadOptions) *Downloader {
if opts == nil {
opts = DefaultDownloadOptions
}
return &Downloader{opts: opts}
}
// The Downloader structure that calls Download(). It is safe to call Download()
// on this structure for multiple objects and across concurrent goroutines.
type Downloader struct {
opts *DownloadOptions
}
// Download downloads an object in S3 and writes the payload into w using
// concurrent GET requests.
//
// It is safe to call this method for multiple objects and across concurrent
// goroutines.
func (d *Downloader) Download(w io.WriterAt, input *s3.GetObjectInput) (n int64, err error) {
impl := downloader{w: w, in: input, opts: *d.opts}
return impl.download()
}
// downloader is the implementation structure used internally by Downloader.
type downloader struct {
opts DownloadOptions
in *s3.GetObjectInput
w io.WriterAt
wg sync.WaitGroup
m sync.Mutex
pos int64
totalBytes int64
written int64
err error
}
// init initializes the downloader with default options.
func (d *downloader) init() {
d.totalBytes = -1
if d.opts.Concurrency == 0 {
d.opts.Concurrency = DefaultDownloadConcurrency
}
if d.opts.PartSize == 0 {
d.opts.PartSize = DefaultDownloadPartSize
}
if d.opts.S3 == nil {
d.opts.S3 = s3.New(nil)
}
}
// download performs the implementation of the object download across ranged
// GETs.
func (d *downloader) download() (n int64, err error) {
d.init()
// Spin up workers
ch := make(chan dlchunk, d.opts.Concurrency)
for i := 0; i < d.opts.Concurrency; i++ {
d.wg.Add(1)
go d.downloadPart(ch)
}
// Assign work
for d.geterr() == nil {
if d.pos != 0 {
// This is not the first chunk, let's wait until we know the total
// size of the payload so we can see if we have read the entire
// object.
total := d.getTotalBytes()
if total < 0 {
// Total has not yet been set, so sleep and loop around while
// waiting for our first worker to resolve this value.
time.Sleep(10 * time.Millisecond)
continue
} else if d.pos >= total {
break // We're finished queueing chunks
}
}
// Queue the next range of bytes to read.
ch <- dlchunk{w: d.w, start: d.pos, size: d.opts.PartSize}
d.pos += d.opts.PartSize
}
// Wait for completion
close(ch)
d.wg.Wait()
// Return error
return d.written, d.err
}
// downloadPart is an individual goroutine worker reading from the ch channel
// and performing a GetObject request on the data with a given byte range.
//
// If this is the first worker, this operation also resolves the total number
// of bytes to be read so that the worker manager knows when it is finished.
func (d *downloader) downloadPart(ch chan dlchunk) {
defer d.wg.Done()
for {
chunk, ok := <-ch
if !ok {
break
}
if d.geterr() == nil {
// Get the next byte range of data
in := &s3.GetObjectInput{}
awsutil.Copy(in, d.in)
rng := fmt.Sprintf("bytes=%d-%d",
chunk.start, chunk.start+chunk.size-1)
in.Range = &rng
resp, err := d.opts.S3.GetObject(in)
if err != nil {
d.seterr(err)
} else {
d.setTotalBytes(resp) // Set total if not yet set.
n, err := io.Copy(&chunk, resp.Body)
resp.Body.Close()
if err != nil {
d.seterr(err)
}
d.incrwritten(n)
}
}
}
}
// getTotalBytes is a thread-safe getter for retrieving the total byte status.
func (d *downloader) getTotalBytes() int64 {
d.m.Lock()
defer d.m.Unlock()
return d.totalBytes
}
// getTotalBytes is a thread-safe setter for setting the total byte status.
func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
d.m.Lock()
defer d.m.Unlock()
if d.totalBytes >= 0 {
return
}
parts := strings.Split(*resp.ContentRange, "/")
total, err := strconv.ParseInt(parts[len(parts)-1], 10, 64)
if err != nil {
d.err = err
return
}
d.totalBytes = total
}
func (d *downloader) incrwritten(n int64) {
d.m.Lock()
defer d.m.Unlock()
d.written += n
}
// geterr is a thread-safe getter for the error object
func (d *downloader) geterr() error {
d.m.Lock()
defer d.m.Unlock()
return d.err
}
// seterr is a thread-safe setter for the error object
func (d *downloader) seterr(e error) {
d.m.Lock()
defer d.m.Unlock()
d.err = e
}
// dlchunk represents a single chunk of data to write by the worker routine.
// This structure also implements an io.SectionReader style interface for
// io.WriterAt, effectively making it an io.SectionWriter (which does not
// exist).
type dlchunk struct {
w io.WriterAt
start int64
size int64
cur int64
}
// Write wraps io.WriterAt for the dlchunk, writing from the dlchunk's start
// position to its end (or EOF).
func (c *dlchunk) Write(p []byte) (n int, err error) {
if c.cur >= c.size {
return 0, io.EOF
}
n, err = c.w.WriteAt(p, c.start+c.cur)
c.cur += int64(n)
return
}

View file

@ -1,165 +0,0 @@
package s3manager_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"regexp"
"strconv"
"sync"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
func dlLoggingSvc(data []byte) (*s3.S3, *[]string, *[]string) {
var m sync.Mutex
names := []string{}
ranges := []string{}
svc := s3.New(nil)
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *aws.Request) {
m.Lock()
defer m.Unlock()
names = append(names, r.Operation.Name)
ranges = append(ranges, *r.Params.(*s3.GetObjectInput).Range)
rerng := regexp.MustCompile(`bytes=(\d+)-(\d+)`)
rng := rerng.FindStringSubmatch(r.HTTPRequest.Header.Get("Range"))
start, _ := strconv.ParseInt(rng[1], 10, 64)
fin, _ := strconv.ParseInt(rng[2], 10, 64)
fin++
if fin > int64(len(data)) {
fin = int64(len(data))
}
r.HTTPResponse = &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(data[start:fin])),
Header: http.Header{},
}
r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d",
start, fin, len(data)))
})
return svc, &names, &ranges
}
type dlwriter struct {
buf []byte
}
func newDLWriter(size int) *dlwriter {
return &dlwriter{buf: make([]byte, size)}
}
func (d dlwriter) WriteAt(p []byte, pos int64) (n int, err error) {
if pos > int64(len(d.buf)) {
return 0, io.EOF
}
written := 0
for i, b := range p {
if i >= len(d.buf) {
break
}
d.buf[pos+int64(i)] = b
written++
}
return written, nil
}
func TestDownloadOrder(t *testing.T) {
s, names, ranges := dlLoggingSvc(buf12MB)
opts := &s3manager.DownloadOptions{S3: s, Concurrency: 1}
d := s3manager.NewDownloader(opts)
w := newDLWriter(len(buf12MB))
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(len(buf12MB)), n)
assert.Equal(t, []string{"GetObject", "GetObject", "GetObject"}, *names)
assert.Equal(t, []string{"bytes=0-5242879", "bytes=5242880-10485759", "bytes=10485760-15728639"}, *ranges)
count := 0
for _, b := range w.buf {
count += int(b)
}
assert.Equal(t, 0, count)
}
func TestDownloadZero(t *testing.T) {
s, names, ranges := dlLoggingSvc([]byte{})
opts := &s3manager.DownloadOptions{S3: s}
d := s3manager.NewDownloader(opts)
w := newDLWriter(0)
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(0), n)
assert.Equal(t, []string{"GetObject"}, *names)
assert.Equal(t, []string{"bytes=0-5242879"}, *ranges)
}
func TestDownloadSetPartSize(t *testing.T) {
s, names, ranges := dlLoggingSvc([]byte{1, 2, 3})
opts := &s3manager.DownloadOptions{S3: s, PartSize: 1, Concurrency: 1}
d := s3manager.NewDownloader(opts)
w := newDLWriter(3)
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(3), n)
assert.Equal(t, []string{"GetObject", "GetObject", "GetObject"}, *names)
assert.Equal(t, []string{"bytes=0-0", "bytes=1-1", "bytes=2-2"}, *ranges)
assert.Equal(t, []byte{1, 2, 3}, w.buf)
}
func TestDownloadError(t *testing.T) {
s, names, _ := dlLoggingSvc([]byte{1, 2, 3})
opts := &s3manager.DownloadOptions{S3: s, PartSize: 1, Concurrency: 1}
num := 0
s.Handlers.Send.PushBack(func(r *aws.Request) {
num++
if num > 1 {
r.HTTPResponse.StatusCode = 400
r.HTTPResponse.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
}
})
d := s3manager.NewDownloader(opts)
w := newDLWriter(3)
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.NotNil(t, err)
assert.Equal(t, int64(1), n)
assert.Equal(t, []string{"GetObject", "GetObject"}, *names)
assert.Equal(t, []byte{1, 0, 0}, w.buf)
}

View file

@ -1,562 +0,0 @@
package s3manager
import (
"bytes"
"fmt"
"io"
"sort"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/service/s3"
)
// The maximum allowed number of parts in a multi-part upload on Amazon S3.
var MaxUploadParts = 10000
// The minimum allowed part size when uploading a part to Amazon S3.
var MinUploadPartSize int64 = 1024 * 1024 * 5
// The default part size to buffer chunks of a payload into.
var DefaultUploadPartSize = MinUploadPartSize
// The default number of goroutines to spin up when using Upload().
var DefaultUploadConcurrency = 5
// The default set of options used when opts is nil in Upload().
var DefaultUploadOptions = &UploadOptions{
PartSize: DefaultUploadPartSize,
Concurrency: DefaultUploadConcurrency,
LeavePartsOnError: false,
S3: nil,
}
// A MultiUploadFailure wraps a failed S3 multipart upload. An error returned
// will satisfy this interface when a multi part upload failed to upload all
// chucks to S3. In the case of a failure the UploadID is needed to operate on
// the chunks, if any, which were uploaded.
//
// Example:
//
// u := s3manager.NewUploader(opts)
// output, err := u.upload(input)
// if err != nil {
// if multierr, ok := err.(MultiUploadFailure); ok {
// // Process error and its associated uploadID
// fmt.Println("Error:", multierr.Code(), multierr.Message(), multierr.UploadID())
// } else {
// // Process error generically
// fmt.Println("Error:", err.Error())
// }
// }
//
type MultiUploadFailure interface {
awserr.Error
// Returns the upload id for the S3 multipart upload that failed.
UploadID() string
}
// So that the Error interface type can be included as an anonymous field
// in the multiUploadError struct and not conflict with the error.Error() method.
type awsError awserr.Error
// A multiUploadError wraps the upload ID of a failed s3 multipart upload.
// Composed of BaseError for code, message, and original error
//
// Should be used for an error that occurred failing a S3 multipart upload,
// and a upload ID is available. If an uploadID is not available a more relevant
type multiUploadError struct {
awsError
// ID for multipart upload which failed.
uploadID string
}
// Error returns the string representation of the error.
//
// See apierr.BaseError ErrorWithExtra for output format
//
// Satisfies the error interface.
func (m multiUploadError) Error() string {
extra := fmt.Sprintf("upload id: %s", m.uploadID)
return awserr.SprintError(m.Code(), m.Message(), extra, m.OrigErr())
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (m multiUploadError) String() string {
return m.Error()
}
// UploadID returns the id of the S3 upload which failed.
func (m multiUploadError) UploadID() string {
return m.uploadID
}
// UploadInput contains all input for upload requests to Amazon S3.
type UploadInput struct {
// The canned ACL to apply to the object.
ACL *string `location:"header" locationName:"x-amz-acl" type:"string"`
Bucket *string `location:"uri" locationName:"Bucket" type:"string" required:"true"`
// Specifies caching behavior along the request/reply chain.
CacheControl *string `location:"header" locationName:"Cache-Control" type:"string"`
// Specifies presentational information for the object.
ContentDisposition *string `location:"header" locationName:"Content-Disposition" type:"string"`
// Specifies what content encodings have been applied to the object and thus
// what decoding mechanisms must be applied to obtain the media-type referenced
// by the Content-Type header field.
ContentEncoding *string `location:"header" locationName:"Content-Encoding" type:"string"`
// The language the content is in.
ContentLanguage *string `location:"header" locationName:"Content-Language" type:"string"`
// A standard MIME type describing the format of the object data.
ContentType *string `location:"header" locationName:"Content-Type" type:"string"`
// The date and time at which the object is no longer cacheable.
Expires *time.Time `location:"header" locationName:"Expires" type:"timestamp" timestampFormat:"rfc822"`
// Gives the grantee READ, READ_ACP, and WRITE_ACP permissions on the object.
GrantFullControl *string `location:"header" locationName:"x-amz-grant-full-control" type:"string"`
// Allows grantee to read the object data and its metadata.
GrantRead *string `location:"header" locationName:"x-amz-grant-read" type:"string"`
// Allows grantee to read the object ACL.
GrantReadACP *string `location:"header" locationName:"x-amz-grant-read-acp" type:"string"`
// Allows grantee to write the ACL for the applicable object.
GrantWriteACP *string `location:"header" locationName:"x-amz-grant-write-acp" type:"string"`
Key *string `location:"uri" locationName:"Key" type:"string" required:"true"`
// A map of metadata to store with the object in S3.
Metadata map[string]*string `location:"headers" locationName:"x-amz-meta-" type:"map"`
// Confirms that the requester knows that she or he will be charged for the
// request. Bucket owners need not specify this parameter in their requests.
// Documentation on downloading objects from requester pays buckets can be found
// at http://docs.aws.amazon.com/AmazonS3/latest/dev/ObjectsinRequesterPaysBuckets.html
RequestPayer *string `location:"header" locationName:"x-amz-request-payer" type:"string"`
// Specifies the algorithm to use to when encrypting the object (e.g., AES256,
// aws:kms).
SSECustomerAlgorithm *string `location:"header" locationName:"x-amz-server-side-encryption-customer-algorithm" type:"string"`
// Specifies the customer-provided encryption key for Amazon S3 to use in encrypting
// data. This value is used to store the object and then it is discarded; Amazon
// does not store the encryption key. The key must be appropriate for use with
// the algorithm specified in the x-amz-server-side-encryption-customer-algorithm
// header.
SSECustomerKey *string `location:"header" locationName:"x-amz-server-side-encryption-customer-key" type:"string"`
// Specifies the 128-bit MD5 digest of the encryption key according to RFC 1321.
// Amazon S3 uses this header for a message integrity check to ensure the encryption
// key was transmitted without error.
SSECustomerKeyMD5 *string `location:"header" locationName:"x-amz-server-side-encryption-customer-key-MD5" type:"string"`
// Specifies the AWS KMS key ID to use for object encryption. All GET and PUT
// requests for an object protected by AWS KMS will fail if not made via SSL
// or using SigV4. Documentation on configuring any of the officially supported
// AWS SDKs and CLI can be found at http://docs.aws.amazon.com/AmazonS3/latest/dev/UsingAWSSDK.html#specify-signature-version
SSEKMSKeyID *string `location:"header" locationName:"x-amz-server-side-encryption-aws-kms-key-id" type:"string"`
// The Server-side encryption algorithm used when storing this object in S3
// (e.g., AES256, aws:kms).
ServerSideEncryption *string `location:"header" locationName:"x-amz-server-side-encryption" type:"string"`
// The type of storage to use for the object. Defaults to 'STANDARD'.
StorageClass *string `location:"header" locationName:"x-amz-storage-class" type:"string"`
// If the bucket is configured as a website, redirects requests for this object
// to another object in the same bucket or to an external URL. Amazon S3 stores
// the value of this header in the object metadata.
WebsiteRedirectLocation *string `location:"header" locationName:"x-amz-website-redirect-location" type:"string"`
// The readable body payload to send to S3.
Body io.Reader
}
// UploadOutput represents a response from the Upload() call.
type UploadOutput struct {
// The URL where the object was uploaded to.
Location string
// The ID for a multipart upload to S3. In the case of an error the error
// can be cast to the MultiUploadFailure interface to extract the upload ID.
UploadID string
}
// UploadOptions keeps tracks of extra options to pass to an Upload() call.
type UploadOptions struct {
// The buffer size (in bytes) to use when buffering data into chunks and
// sending them as parts to S3. The minimum allowed part size is 5MB, and
// if this value is set to zero, the DefaultPartSize value will be used.
PartSize int64
// The number of goroutines to spin up in parallel when sending parts.
// If this is set to zero, the DefaultConcurrency value will be used.
Concurrency int
// Setting this value to true will cause the SDK to avoid calling
// AbortMultipartUpload on a failure, leaving all successfully uploaded
// parts on S3 for manual recovery.
//
// Note that storing parts of an incomplete multipart upload counts towards
// space usage on S3 and will add additional costs if not cleaned up.
LeavePartsOnError bool
// The client to use when uploading to S3. Leave this as nil to use the
// default S3 client.
S3 *s3.S3
}
// NewUploader creates a new Uploader object to upload data to S3. Pass in
// an optional opts structure to customize the uploader behavior.
func NewUploader(opts *UploadOptions) *Uploader {
if opts == nil {
opts = DefaultUploadOptions
}
return &Uploader{opts: opts}
}
// The Uploader structure that calls Upload(). It is safe to call Upload()
// on this structure for multiple objects and across concurrent goroutines.
type Uploader struct {
opts *UploadOptions
}
// Upload uploads an object to S3, intelligently buffering large files into
// smaller chunks and sending them in parallel across multiple goroutines. You
// can configure the buffer size and concurrency through the opts parameter.
//
// If opts is set to nil, DefaultUploadOptions will be used.
//
// It is safe to call this method for multiple objects and across concurrent
// goroutines.
func (u *Uploader) Upload(input *UploadInput) (*UploadOutput, error) {
i := uploader{in: input, opts: *u.opts}
return i.upload()
}
// internal structure to manage an upload to S3.
type uploader struct {
in *UploadInput
opts UploadOptions
readerPos int64 // current reader position
totalSize int64 // set to -1 if the size is not known
}
// internal logic for deciding whether to upload a single part or use a
// multipart upload.
func (u *uploader) upload() (*UploadOutput, error) {
u.init()
if u.opts.PartSize < MinUploadPartSize {
msg := fmt.Sprintf("part size must be at least %d bytes", MinUploadPartSize)
return nil, awserr.New("ConfigError", msg, nil)
}
// Do one read to determine if we have more than one part
buf, err := u.nextReader()
if err == io.EOF || err == io.ErrUnexpectedEOF { // single part
return u.singlePart(buf)
} else if err != nil {
return nil, awserr.New("ReadRequestBody", "read upload data failed", err)
}
mu := multiuploader{uploader: u}
return mu.upload(buf)
}
// init will initialize all default options.
func (u *uploader) init() {
if u.opts.S3 == nil {
u.opts.S3 = s3.New(nil)
}
if u.opts.Concurrency == 0 {
u.opts.Concurrency = DefaultUploadConcurrency
}
if u.opts.PartSize == 0 {
u.opts.PartSize = DefaultUploadPartSize
}
// Try to get the total size for some optimizations
u.initSize()
}
// initSize tries to detect the total stream size, setting u.totalSize. If
// the size is not known, totalSize is set to -1.
func (u *uploader) initSize() {
u.totalSize = -1
switch r := u.in.Body.(type) {
case io.Seeker:
pos, _ := r.Seek(0, 1)
defer r.Seek(pos, 0)
n, err := r.Seek(0, 2)
if err != nil {
return
}
u.totalSize = n
// try to adjust partSize if it is too small
if u.totalSize/u.opts.PartSize >= int64(MaxUploadParts) {
u.opts.PartSize = u.totalSize / int64(MaxUploadParts)
}
}
}
// nextReader returns a seekable reader representing the next packet of data.
// This operation increases the shared u.readerPos counter, but note that it
// does not need to be wrapped in a mutex because nextReader is only called
// from the main thread.
func (u *uploader) nextReader() (io.ReadSeeker, error) {
switch r := u.in.Body.(type) {
case io.ReaderAt:
var err error
n := u.opts.PartSize
if u.totalSize >= 0 {
bytesLeft := u.totalSize - u.readerPos
if bytesLeft == 0 {
err = io.EOF
} else if bytesLeft <= u.opts.PartSize {
err = io.ErrUnexpectedEOF
n = bytesLeft
}
}
buf := io.NewSectionReader(r, u.readerPos, n)
u.readerPos += n
return buf, err
default:
packet := make([]byte, u.opts.PartSize)
n, err := io.ReadFull(u.in.Body, packet)
u.readerPos += int64(n)
return bytes.NewReader(packet[0:n]), err
}
}
// singlePart contains upload logic for uploading a single chunk via
// a regular PutObject request. Multipart requests require at least two
// parts, or at least 5MB of data.
func (u *uploader) singlePart(buf io.ReadSeeker) (*UploadOutput, error) {
params := &s3.PutObjectInput{}
awsutil.Copy(params, u.in)
params.Body = buf
req, _ := u.opts.S3.PutObjectRequest(params)
if err := req.Send(); err != nil {
return nil, err
}
url := req.HTTPRequest.URL.String()
return &UploadOutput{Location: url}, nil
}
// internal structure to manage a specific multipart upload to S3.
type multiuploader struct {
*uploader
wg sync.WaitGroup
m sync.Mutex
err error
uploadID string
parts completedParts
}
// keeps track of a single chunk of data being sent to S3.
type chunk struct {
buf io.ReadSeeker
num int64
}
// completedParts is a wrapper to make parts sortable by their part number,
// since S3 required this list to be sent in sorted order.
type completedParts []*s3.CompletedPart
func (a completedParts) Len() int { return len(a) }
func (a completedParts) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a completedParts) Less(i, j int) bool { return *a[i].PartNumber < *a[j].PartNumber }
// upload will perform a multipart upload using the firstBuf buffer containing
// the first chunk of data.
func (u *multiuploader) upload(firstBuf io.ReadSeeker) (*UploadOutput, error) {
params := &s3.CreateMultipartUploadInput{}
awsutil.Copy(params, u.in)
// Create the multipart
resp, err := u.opts.S3.CreateMultipartUpload(params)
if err != nil {
return nil, err
}
u.uploadID = *resp.UploadID
// Create the workers
ch := make(chan chunk, u.opts.Concurrency)
for i := 0; i < u.opts.Concurrency; i++ {
u.wg.Add(1)
go u.readChunk(ch)
}
// Send part 1 to the workers
var num int64 = 1
ch <- chunk{buf: firstBuf, num: num}
// Read and queue the rest of the parts
for u.geterr() == nil {
// This upload exceeded maximum number of supported parts, error now.
if num > int64(MaxUploadParts) {
msg := fmt.Sprintf("exceeded total allowed parts (%d). "+
"Adjust PartSize to fit in this limit", MaxUploadParts)
u.seterr(awserr.New("TotalPartsExceeded", msg, nil))
break
}
num++
buf, err := u.nextReader()
if err == io.EOF {
break
}
ch <- chunk{buf: buf, num: num}
if err != nil && err != io.ErrUnexpectedEOF {
u.seterr(awserr.New(
"ReadRequestBody",
"read multipart upload data failed",
err))
break
}
}
// Close the channel, wait for workers, and complete upload
close(ch)
u.wg.Wait()
complete := u.complete()
if err := u.geterr(); err != nil {
return nil, &multiUploadError{
awsError: awserr.New(
"MultipartUpload",
"upload multipart failed",
err),
uploadID: u.uploadID,
}
}
return &UploadOutput{
Location: *complete.Location,
UploadID: u.uploadID,
}, nil
}
// readChunk runs in worker goroutines to pull chunks off of the ch channel
// and send() them as UploadPart requests.
func (u *multiuploader) readChunk(ch chan chunk) {
defer u.wg.Done()
for {
data, ok := <-ch
if !ok {
break
}
if u.geterr() == nil {
if err := u.send(data); err != nil {
u.seterr(err)
}
}
}
}
// send performs an UploadPart request and keeps track of the completed
// part information.
func (u *multiuploader) send(c chunk) error {
resp, err := u.opts.S3.UploadPart(&s3.UploadPartInput{
Bucket: u.in.Bucket,
Key: u.in.Key,
Body: c.buf,
UploadID: &u.uploadID,
PartNumber: &c.num,
})
if err != nil {
return err
}
n := c.num
completed := &s3.CompletedPart{ETag: resp.ETag, PartNumber: &n}
u.m.Lock()
u.parts = append(u.parts, completed)
u.m.Unlock()
return nil
}
// geterr is a thread-safe getter for the error object
func (u *multiuploader) geterr() error {
u.m.Lock()
defer u.m.Unlock()
return u.err
}
// seterr is a thread-safe setter for the error object
func (u *multiuploader) seterr(e error) {
u.m.Lock()
defer u.m.Unlock()
u.err = e
}
// fail will abort the multipart unless LeavePartsOnError is set to true.
func (u *multiuploader) fail() {
if u.opts.LeavePartsOnError {
return
}
u.opts.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
Bucket: u.in.Bucket,
Key: u.in.Key,
UploadID: &u.uploadID,
})
}
// complete successfully completes a multipart upload and returns the response.
func (u *multiuploader) complete() *s3.CompleteMultipartUploadOutput {
if u.geterr() != nil {
u.fail()
return nil
}
// Parts must be sorted in PartNumber order.
sort.Sort(u.parts)
resp, err := u.opts.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
Bucket: u.in.Bucket,
Key: u.in.Key,
UploadID: &u.uploadID,
MultipartUpload: &s3.CompletedMultipartUpload{Parts: u.parts},
})
if err != nil {
u.seterr(err)
u.fail()
}
return resp
}

View file

@ -1,438 +0,0 @@
package s3manager_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"sort"
"sync"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
var buf12MB = make([]byte, 1024*1024*12)
var buf2MB = make([]byte, 1024*1024*2)
var emptyList = []string{}
func val(i interface{}, s string) interface{} {
return awsutil.ValuesAtPath(i, s)[0]
}
func contains(src []string, s string) bool {
for _, v := range src {
if s == v {
return true
}
}
return false
}
func loggingSvc(ignoreOps []string) (*s3.S3, *[]string, *[]interface{}) {
var m sync.Mutex
partNum := 0
names := []string{}
params := []interface{}{}
svc := s3.New(nil)
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.UnmarshalError.Clear()
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *aws.Request) {
m.Lock()
defer m.Unlock()
if !contains(ignoreOps, r.Operation.Name) {
names = append(names, r.Operation.Name)
params = append(params, r.Params)
}
r.HTTPResponse = &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
switch data := r.Data.(type) {
case *s3.CreateMultipartUploadOutput:
data.UploadID = aws.String("UPLOAD-ID")
case *s3.UploadPartOutput:
partNum++
data.ETag = aws.String(fmt.Sprintf("ETAG%d", partNum))
case *s3.CompleteMultipartUploadOutput:
data.Location = aws.String("https://location")
}
})
return svc, &names, &params
}
func buflen(i interface{}) int {
r := i.(io.Reader)
b, _ := ioutil.ReadAll(r)
return len(b)
}
func TestUploadOrderMulti(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
ServerSideEncryption: aws.String("AES256"),
ContentType: aws.String("content/type"),
})
assert.NoError(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
assert.Equal(t, "https://location", resp.Location)
assert.Equal(t, "UPLOAD-ID", resp.UploadID)
// Validate input values
// UploadPart
assert.Equal(t, "UPLOAD-ID", val((*args)[1], "UploadID"))
assert.Equal(t, "UPLOAD-ID", val((*args)[2], "UploadID"))
assert.Equal(t, "UPLOAD-ID", val((*args)[3], "UploadID"))
// CompleteMultipartUpload
assert.Equal(t, "UPLOAD-ID", val((*args)[4], "UploadID"))
assert.Equal(t, int64(1), val((*args)[4], "MultipartUpload.Parts[0].PartNumber"))
assert.Equal(t, int64(2), val((*args)[4], "MultipartUpload.Parts[1].PartNumber"))
assert.Equal(t, int64(3), val((*args)[4], "MultipartUpload.Parts[2].PartNumber"))
assert.Regexp(t, `^ETAG\d+$`, val((*args)[4], "MultipartUpload.Parts[0].ETag"))
assert.Regexp(t, `^ETAG\d+$`, val((*args)[4], "MultipartUpload.Parts[1].ETag"))
assert.Regexp(t, `^ETAG\d+$`, val((*args)[4], "MultipartUpload.Parts[2].ETag"))
// Custom headers
assert.Equal(t, "AES256", val((*args)[0], "ServerSideEncryption"))
assert.Equal(t, "content/type", val((*args)[0], "ContentType"))
}
func TestUploadOrderMultiDifferentPartSize(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{
S3: s,
PartSize: 1024 * 1024 * 7,
Concurrency: 1,
})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
assert.NoError(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
// Part lengths
assert.Equal(t, 1024*1024*7, buflen(val((*args)[1], "Body")))
assert.Equal(t, 1024*1024*5, buflen(val((*args)[2], "Body")))
}
func TestUploadIncreasePartSize(t *testing.T) {
s3manager.MaxUploadParts = 2
defer func() { s3manager.MaxUploadParts = 10000 }()
s, ops, args := loggingSvc(emptyList)
opts := &s3manager.UploadOptions{S3: s, Concurrency: 1}
mgr := s3manager.NewUploader(opts)
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
assert.NoError(t, err)
assert.Equal(t, int64(0), opts.PartSize) // don't modify orig options
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
// Part lengths
assert.Equal(t, 1024*1024*6, buflen(val((*args)[1], "Body")))
assert.Equal(t, 1024*1024*6, buflen(val((*args)[2], "Body")))
}
func TestUploadFailIfPartSizeTooSmall(t *testing.T) {
opts := &s3manager.UploadOptions{PartSize: 5}
mgr := s3manager.NewUploader(opts)
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
assert.Nil(t, resp)
assert.NotNil(t, err)
aerr := err.(awserr.Error)
assert.Equal(t, "ConfigError", aerr.Code())
assert.Contains(t, aerr.Message(), "part size must be at least")
}
func TestUploadOrderSingle(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf2MB),
ServerSideEncryption: aws.String("AES256"),
ContentType: aws.String("content/type"),
})
assert.NoError(t, err)
assert.Equal(t, []string{"PutObject"}, *ops)
assert.NotEqual(t, "", resp.Location)
assert.Equal(t, "", resp.UploadID)
assert.Equal(t, "AES256", val((*args)[0], "ServerSideEncryption"))
assert.Equal(t, "content/type", val((*args)[0], "ContentType"))
}
func TestUploadOrderSingleFailure(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *aws.Request) {
r.HTTPResponse.StatusCode = 400
})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf2MB),
})
assert.Error(t, err)
assert.Equal(t, []string{"PutObject"}, *ops)
assert.Nil(t, resp)
}
func TestUploadOrderZero(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(make([]byte, 0)),
})
assert.NoError(t, err)
assert.Equal(t, []string{"PutObject"}, *ops)
assert.NotEqual(t, "", resp.Location)
assert.Equal(t, "", resp.UploadID)
assert.Equal(t, 0, buflen(val((*args)[0], "Body")))
}
func TestUploadOrderMultiFailure(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *aws.Request) {
switch t := r.Data.(type) {
case *s3.UploadPartOutput:
if *t.ETag == "ETAG2" {
r.HTTPResponse.StatusCode = 400
}
}
})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s, Concurrency: 1})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
assert.Error(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "AbortMultipartUpload"}, *ops)
}
func TestUploadOrderMultiFailureOnComplete(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *aws.Request) {
switch r.Data.(type) {
case *s3.CompleteMultipartUploadOutput:
r.HTTPResponse.StatusCode = 400
}
})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s, Concurrency: 1})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
assert.Error(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart",
"UploadPart", "CompleteMultipartUpload", "AbortMultipartUpload"}, *ops)
}
func TestUploadOrderMultiFailureOnCreate(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *aws.Request) {
switch r.Data.(type) {
case *s3.CreateMultipartUploadOutput:
r.HTTPResponse.StatusCode = 400
}
})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(make([]byte, 1024*1024*12)),
})
assert.Error(t, err)
assert.Equal(t, []string{"CreateMultipartUpload"}, *ops)
}
func TestUploadOrderMultiFailureLeaveParts(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *aws.Request) {
switch data := r.Data.(type) {
case *s3.UploadPartOutput:
if *data.ETag == "ETAG2" {
r.HTTPResponse.StatusCode = 400
}
}
})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{
S3: s,
Concurrency: 1,
LeavePartsOnError: true,
})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(make([]byte, 1024*1024*12)),
})
assert.Error(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart"}, *ops)
}
type failreader struct {
times int
failCount int
}
func (f *failreader) Read(b []byte) (int, error) {
f.failCount++
if f.failCount >= f.times {
return 0, fmt.Errorf("random failure")
}
return len(b), nil
}
func TestUploadOrderReadFail1(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &failreader{times: 1},
})
assert.Equal(t, "ReadRequestBody", err.(awserr.Error).Code())
assert.EqualError(t, err.(awserr.Error).OrigErr(), "random failure")
assert.Equal(t, []string{}, *ops)
}
func TestUploadOrderReadFail2(t *testing.T) {
s, ops, _ := loggingSvc([]string{"UploadPart"})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s, Concurrency: 1})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &failreader{times: 2},
})
assert.Equal(t, "ReadRequestBody", err.(awserr.Error).Code())
assert.EqualError(t, err.(awserr.Error).OrigErr(), "random failure")
assert.Equal(t, []string{"CreateMultipartUpload", "AbortMultipartUpload"}, *ops)
}
type sizedReader struct {
size int
cur int
}
func (s *sizedReader) Read(p []byte) (n int, err error) {
if s.cur >= s.size {
return 0, io.EOF
}
n = len(p)
s.cur += len(p)
if s.cur > s.size {
n -= s.cur - s.size
}
return
}
func TestUploadOrderMultiBufferedReader(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &sizedReader{size: 1024 * 1024 * 12},
})
assert.NoError(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
// Part lengths
parts := []int{
buflen(val((*args)[1], "Body")),
buflen(val((*args)[2], "Body")),
buflen(val((*args)[3], "Body")),
}
sort.Ints(parts)
assert.Equal(t, []int{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts)
}
func TestUploadOrderMultiBufferedReaderExceedTotalParts(t *testing.T) {
s3manager.MaxUploadParts = 2
defer func() { s3manager.MaxUploadParts = 10000 }()
s, ops, _ := loggingSvc([]string{"UploadPart"})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s, Concurrency: 1})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &sizedReader{size: 1024 * 1024 * 12},
})
assert.Error(t, err)
assert.Nil(t, resp)
assert.Equal(t, []string{"CreateMultipartUpload", "AbortMultipartUpload"}, *ops)
aerr := err.(awserr.Error)
assert.Equal(t, "TotalPartsExceeded", aerr.Code())
assert.Contains(t, aerr.Message(), "exceeded total allowed parts (2)")
}
func TestUploadOrderSingleBufferedReader(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &sizedReader{size: 1024 * 1024 * 2},
})
assert.NoError(t, err)
assert.Equal(t, []string{"PutObject"}, *ops)
assert.NotEqual(t, "", resp.Location)
assert.Equal(t, "", resp.UploadID)
}

View file

@ -4,27 +4,33 @@ package s3
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/service"
"github.com/aws/aws-sdk-go/aws/service/serviceinfo"
"github.com/aws/aws-sdk-go/internal/protocol/restxml"
"github.com/aws/aws-sdk-go/internal/signer/v4"
)
// S3 is a client for Amazon S3.
type S3 struct {
*aws.Service
*service.Service
}
// Used for custom service initialization logic
var initService func(*aws.Service)
var initService func(*service.Service)
// Used for custom request initialization logic
var initRequest func(*aws.Request)
var initRequest func(*request.Request)
// New returns a new S3 client.
func New(config *aws.Config) *S3 {
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "s3",
APIVersion: "2006-03-01",
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "s3",
APIVersion: "2006-03-01",
},
}
service.Initialize()
@ -45,8 +51,8 @@ func New(config *aws.Config) *S3 {
// newRequest creates a new request for a S3 operation and runs any
// custom request initialization.
func (c *S3) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
func (c *S3) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
// Run custom request initialization if present
if initRequest != nil {

View file

@ -4,14 +4,14 @@ import (
"crypto/md5"
"encoding/base64"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
)
var errSSERequiresSSL = awserr.New("ConfigError", "cannot send SSE keys over HTTP.", nil)
func validateSSERequiresSSL(r *aws.Request) {
func validateSSERequiresSSL(r *request.Request) {
if r.HTTPRequest.URL.Scheme != "https" {
p := awsutil.ValuesAtPath(r.Params, "SSECustomerKey||CopySourceSSECustomerKey")
if len(p) > 0 {
@ -20,7 +20,7 @@ func validateSSERequiresSSL(r *aws.Request) {
}
}
func computeSSEKeys(r *aws.Request) {
func computeSSEKeys(r *request.Request) {
headers := []string{
"x-amz-server-side-encryption-customer-key",
"x-amz-copy-source-server-side-encryption-customer-key",

View file

@ -1,81 +0,0 @@
package s3_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
func TestSSECustomerKeyOverHTTPError(t *testing.T) {
s := s3.New(&aws.Config{DisableSSL: true})
req, _ := s.CopyObjectRequest(&s3.CopyObjectInput{
Bucket: aws.String("bucket"),
CopySource: aws.String("bucket/source"),
Key: aws.String("dest"),
SSECustomerKey: aws.String("key"),
})
err := req.Build()
assert.Error(t, err)
assert.Equal(t, "ConfigError", err.(awserr.Error).Code())
assert.Contains(t, err.(awserr.Error).Message(), "cannot send SSE keys over HTTP")
}
func TestCopySourceSSECustomerKeyOverHTTPError(t *testing.T) {
s := s3.New(&aws.Config{DisableSSL: true})
req, _ := s.CopyObjectRequest(&s3.CopyObjectInput{
Bucket: aws.String("bucket"),
CopySource: aws.String("bucket/source"),
Key: aws.String("dest"),
CopySourceSSECustomerKey: aws.String("key"),
})
err := req.Build()
assert.Error(t, err)
assert.Equal(t, "ConfigError", err.(awserr.Error).Code())
assert.Contains(t, err.(awserr.Error).Message(), "cannot send SSE keys over HTTP")
}
func TestComputeSSEKeys(t *testing.T) {
s := s3.New(nil)
req, _ := s.CopyObjectRequest(&s3.CopyObjectInput{
Bucket: aws.String("bucket"),
CopySource: aws.String("bucket/source"),
Key: aws.String("dest"),
SSECustomerKey: aws.String("key"),
CopySourceSSECustomerKey: aws.String("key"),
})
err := req.Build()
assert.NoError(t, err)
assert.Equal(t, "a2V5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key"))
assert.Equal(t, "a2V5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key"))
assert.Equal(t, "PG4LipwVIkqCKLmpjKFTHQ==", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key-md5"))
assert.Equal(t, "PG4LipwVIkqCKLmpjKFTHQ==", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key-md5"))
}
func TestComputeSSEKeysShortcircuit(t *testing.T) {
s := s3.New(nil)
req, _ := s.CopyObjectRequest(&s3.CopyObjectInput{
Bucket: aws.String("bucket"),
CopySource: aws.String("bucket/source"),
Key: aws.String("dest"),
SSECustomerKey: aws.String("key"),
CopySourceSSECustomerKey: aws.String("key"),
SSECustomerKeyMD5: aws.String("MD5"),
CopySourceSSECustomerKeyMD5: aws.String("MD5"),
})
err := req.Build()
assert.NoError(t, err)
assert.Equal(t, "a2V5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key"))
assert.Equal(t, "a2V5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key"))
assert.Equal(t, "MD5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key-md5"))
assert.Equal(t, "MD5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key-md5"))
}

View file

@ -0,0 +1,36 @@
package s3
import (
"bytes"
"io/ioutil"
"net/http"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
func copyMultipartStatusOKUnmarhsalError(r *request.Request) {
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "unable to read response body", err)
return
}
body := bytes.NewReader(b)
r.HTTPResponse.Body = aws.ReadSeekCloser(body)
defer r.HTTPResponse.Body.(aws.ReaderSeekerCloser).Seek(0, 0)
if body.Len() == 0 {
// If there is no body don't attempt to parse the body.
return
}
unmarshalError(r)
if err, ok := r.Error.(awserr.Error); ok && err != nil {
if err.Code() == "SerializationError" {
r.Error = nil
return
}
r.HTTPResponse.StatusCode = http.StatusServiceUnavailable
}
}

View file

@ -2,11 +2,14 @@ package s3
import (
"encoding/xml"
"fmt"
"io"
"net/http"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
type xmlErrorResponse struct {
@ -15,9 +18,15 @@ type xmlErrorResponse struct {
Message string `xml:"Message"`
}
func unmarshalError(r *aws.Request) {
func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
if r.HTTPResponse.StatusCode == http.StatusMovedPermanently {
r.Error = awserr.New("BucketRegionError",
fmt.Sprintf("incorrect region, the bucket is not in '%s' region", aws.StringValue(r.Service.Config.Region)), nil)
return
}
if r.HTTPResponse.ContentLength == int64(0) {
// No body, use status code to generate an awserr.Error
r.Error = awserr.NewRequestFailure(

View file

@ -1,53 +0,0 @@
package s3_test
import (
"bytes"
"io/ioutil"
"net/http"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
var s3StatusCodeErrorTests = []struct {
scode int
status string
body string
code string
message string
}{
{301, "Moved Permanently", "", "MovedPermanently", "Moved Permanently"},
{403, "Forbidden", "", "Forbidden", "Forbidden"},
{400, "Bad Request", "", "BadRequest", "Bad Request"},
{404, "Not Found", "", "NotFound", "Not Found"},
{500, "Internal Error", "", "InternalError", "Internal Error"},
}
func TestStatusCodeError(t *testing.T) {
for _, test := range s3StatusCodeErrorTests {
s := s3.New(nil)
s.Handlers.Send.Clear()
s.Handlers.Send.PushBack(func(r *aws.Request) {
body := ioutil.NopCloser(bytes.NewReader([]byte(test.body)))
r.HTTPResponse = &http.Response{
ContentLength: int64(len(test.body)),
StatusCode: test.scode,
Status: test.status,
Body: body,
}
})
_, err := s.PutBucketACL(&s3.PutBucketACLInput{
Bucket: aws.String("bucket"), ACL: aws.String("public-read"),
})
assert.Error(t, err)
assert.Equal(t, test.code, err.(awserr.Error).Code())
assert.Equal(t, test.message, err.(awserr.Error).Message())
}
}

View file

@ -1 +0,0 @@
(Stub, TBA)

View file

@ -1,68 +0,0 @@
TODO:
* What are filters?
* List+explain all existing filters (pongo2 + pongo2-addons)
Implemented filters so far which needs documentation:
* escape
* safe
* escapejs
* add
* addslashes
* capfirst
* center
* cut
* date
* default
* default_if_none
* divisibleby
* first
* floatformat
* get_digit
* iriencode
* join
* last
* length
* length_is
* linebreaks
* linebreaksbr
* linenumbers
* ljust
* lower
* make_list
* phone2numeric
* pluralize
* random
* removetags
* rjust
* slice
* stringformat
* striptags
* time
* title
* truncatechars
* truncatechars_html
* truncatewords
* truncatewords_html
* upper
* urlencode
* urlize
* urlizetrunc
* wordcount
* wordwrap
* yesno
* filesizeformat*
* slugify*
* truncatesentences*
* truncatesentences_html*
* markdown*
* intcomma*
* ordinal*
* naturalday*
* timesince*
* timeuntil*
* naturaltime*
Filters marked with * are available through [pongo2-addons](https://github.com/flosch/pongo2-addons).

View file

@ -1 +0,0 @@
(Stub, TBA)

View file

@ -1 +0,0 @@
(Stub, TBA)

View file

@ -1,31 +0,0 @@
TODO:
* What are tags?
* List+explain all existing tags (pongo2 + pongo2-addons)
Implemented tags so far which needs documentation:
* autoescape
* block
* comment
* cycle
* extends
* filter
* firstof
* for
* if
* ifchanged
* ifequal
* ifnotequal
* import
* include
* lorem
* macro
* now
* set
* spaceless
* ssi
* templatetag
* verbatim
* widthratio
* with

View file

@ -1 +0,0 @@
(Stub, TBA)

View file

View file

@ -1,10 +0,0 @@
{{ "<script>alert('xss');</script>" }}
{% autoescape off %}
{{ "<script>alert('xss');</script>" }}
{% endautoescape %}
{% autoescape on %}
{{ "<script>alert('xss');</script>" }}
{% endautoescape %}
{% autoescape off %}
{{ "<script>alert('xss');</script>"|escape }}
{% endautoescape %}

Some files were not shown because too many files have changed in this diff Show more