1
0
Fork 0
mirror of https://github.com/Luzifer/mondash.git synced 2025-01-10 12:51:49 +00:00
mondash/vendor/github.com/aws/aws-sdk-go/service/s3/s3manager/download.go

257 lines
6.1 KiB
Go

package s3manager
import (
"fmt"
"io"
"strconv"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/service/s3"
)
// Package-level defaults used by the downloader when an option is left
// unset.
var (
	// DefaultDownloadPartSize is the default range of bytes to get at a time
	// when using Download().
	DefaultDownloadPartSize int64 = 5 * 1024 * 1024

	// DefaultDownloadConcurrency is the default number of goroutines to spin
	// up when using Download().
	DefaultDownloadConcurrency = 5

	// DefaultDownloadOptions is the default set of options used when opts is
	// nil in NewDownloader().
	DefaultDownloadOptions = &DownloadOptions{
		Concurrency: DefaultDownloadConcurrency,
		PartSize:    DefaultDownloadPartSize,
	}
)
// DownloadOptions keeps track of extra options to pass to a Download() call.
type DownloadOptions struct {
// The size (in bytes) of the byte range requested by each ranged GET.
// If this value is set to zero, the DefaultDownloadPartSize value will be
// used.
PartSize int64
// The number of goroutines to spin up in parallel when downloading parts.
// If this is set to zero, the DefaultDownloadConcurrency value will be used.
Concurrency int
// An S3 client to use when performing downloads. Leave this as nil to use
// a default client.
S3 *s3.S3
}
// NewDownloader builds a Downloader that fetches an object from S3 in
// concurrent chunks. opts is optional: when it is nil the Downloader is
// backed by DefaultDownloadOptions, otherwise by the options supplied.
func NewDownloader(opts *DownloadOptions) *Downloader {
	dl := &Downloader{opts: DefaultDownloadOptions}
	if opts != nil {
		dl.opts = opts
	}
	return dl
}
// Downloader is the structure that calls Download(). It is safe to call
// Download() on this structure for multiple objects and across concurrent
// goroutines.
type Downloader struct {
// opts holds the options handed to NewDownloader; Download() works on a
// per-call copy of the pointed-to value.
opts *DownloadOptions
}
// Download downloads an object in S3 and writes the payload into w using
// concurrent GET requests.
//
// w is the destination for the object's bytes and input describes the
// object to fetch. The returned n is the number of bytes written and err is
// the error recorded while downloading, if any.
//
// Each call works on its own copy of the Downloader's options, so applying
// defaults never modifies the options shared between calls. A Downloader
// that was not created through NewDownloader (and therefore has no options)
// falls back to the package defaults instead of panicking.
//
// It is safe to call this method for multiple objects and across concurrent
// goroutines.
func (d *Downloader) Download(w io.WriterAt, input *s3.GetObjectInput) (n int64, err error) {
	// Zero-valued options are completed with defaults by the internal
	// downloader's init step.
	var opts DownloadOptions
	if d.opts != nil {
		opts = *d.opts
	}

	impl := downloader{w: w, in: input, opts: opts}
	return impl.download()
}
// downloader is the implementation structure used internally by Downloader.
// Download() creates one per call, so it carries the state of a single
// object download.
type downloader struct {
// opts is this call's private copy of the options; init() fills in defaults.
opts DownloadOptions
// in is the caller's request, copied for every ranged GET.
in *s3.GetObjectInput
// w is the destination handed to each queued chunk.
w io.WriterAt
// wg tracks the worker goroutines so download() can wait for them.
wg sync.WaitGroup
// m is the mutex held by the getters/setters for totalBytes, written and err.
m sync.Mutex
// pos is the start offset of the next chunk to queue.
pos int64
// totalBytes is the object's total size; init() sets it to -1 until a
// response provides the real value.
totalBytes int64
// written is the running total of bytes copied by the workers.
written int64
// err is the most recently recorded error, if any.
err error
}
// init prepares the downloader for a run: it marks the total object size as
// unknown (-1) and replaces every unset option with its default.
//
// Non-positive Concurrency and PartSize values are treated as unset. A
// negative concurrency would otherwise panic when the work channel is
// allocated, and a non-positive part size cannot describe a valid byte
// range.
func (d *downloader) init() {
	d.totalBytes = -1

	if d.opts.Concurrency <= 0 {
		d.opts.Concurrency = DefaultDownloadConcurrency
	}
	if d.opts.PartSize <= 0 {
		d.opts.PartSize = DefaultDownloadPartSize
	}
	if d.opts.S3 == nil {
		d.opts.S3 = s3.New(nil)
	}
}
// download drives the whole transfer: it starts the worker goroutines,
// feeds them consecutive byte ranges, waits for them to finish, and reports
// the number of bytes written together with any recorded error.
func (d *downloader) download() (n int64, err error) {
	d.init()

	// Start one worker per configured unit of concurrency, all reading from
	// the same queue.
	queue := make(chan dlchunk, d.opts.Concurrency)
	for started := 0; started < d.opts.Concurrency; started++ {
		d.wg.Add(1)
		go d.downloadPart(queue)
	}

	// Hand out ranges until an error is recorded or the object is covered.
	for d.geterr() == nil {
		// The very first chunk is queued unconditionally; it is the one
		// whose response tells us how large the object is.
		if d.pos != 0 {
			total := d.getTotalBytes()
			if total < 0 {
				// Size still unknown: pause briefly and check again once
				// a worker has had a chance to resolve it.
				time.Sleep(10 * time.Millisecond)
				continue
			}
			if d.pos >= total {
				// Every byte of the object has been assigned to a chunk.
				break
			}
		}

		next := dlchunk{w: d.w, start: d.pos, size: d.opts.PartSize}
		queue <- next
		d.pos += d.opts.PartSize
	}

	// No more work: let the workers drain the queue and exit.
	close(queue)
	d.wg.Wait()

	return d.written, d.err
}
// downloadPart is the body of a single worker goroutine. It takes chunks
// from ch until the channel is closed and, for each one, issues a GetObject
// request restricted to the chunk's byte range and copies the response body
// into the chunk.
//
// The first response handled also resolves the total number of bytes to be
// read, which lets the queueing loop know when it is finished. Once an error
// has been recorded, remaining chunks are received but skipped.
func (d *downloader) downloadPart(ch chan dlchunk) {
	defer d.wg.Done()

	for chunk := range ch {
		if d.geterr() != nil {
			// A failure was already recorded; just drain the queue.
			continue
		}

		// Clone the caller's request and narrow it to this chunk's range.
		req := &s3.GetObjectInput{}
		awsutil.Copy(req, d.in)
		first, last := chunk.start, chunk.start+chunk.size-1
		byteRange := fmt.Sprintf("bytes=%d-%d", first, last)
		req.Range = &byteRange

		resp, reqErr := d.opts.S3.GetObject(req)
		if reqErr != nil {
			d.seterr(reqErr)
			continue
		}

		d.setTotalBytes(resp) // Set total if not yet set.

		copied, copyErr := io.Copy(&chunk, resp.Body)
		resp.Body.Close()
		if copyErr != nil {
			d.seterr(copyErr)
		}
		d.incrwritten(copied)
	}
}
// getTotalBytes reads the total byte count while holding the downloader's
// mutex and returns it.
func (d *downloader) getTotalBytes() int64 {
	d.m.Lock()
	total := d.totalBytes
	d.m.Unlock()
	return total
}
// setTotalBytes is a thread-safe setter for the total byte status. It
// derives the object's total size from a GetObject response and stores it,
// unless a total has already been stored, in which case it does nothing.
//
// The total is the value following the last "/" in the response's
// Content-Range. When the response has no Content-Range, its Content-Length
// is used instead (NOTE(review): presumably the case when the whole object
// was returned rather than a range; confirm against the S3 API docs). If
// neither is present, or the total cannot be parsed as an integer, the
// problem is recorded in d.err so the download stops instead of waiting
// indefinitely for a size that will never arrive.
func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
	d.m.Lock()
	defer d.m.Unlock()

	if d.totalBytes >= 0 {
		return
	}

	if resp.ContentRange == nil {
		if resp.ContentLength == nil {
			d.err = fmt.Errorf("s3manager: response has neither Content-Range nor Content-Length")
			return
		}
		d.totalBytes = *resp.ContentLength
		return
	}

	// Everything after the final "/" is the total; with no "/" present the
	// whole value is parsed, as LastIndex returns -1.
	contentRange := *resp.ContentRange
	total, err := strconv.ParseInt(contentRange[strings.LastIndex(contentRange, "/")+1:], 10, 64)
	if err != nil {
		d.err = err
		return
	}

	d.totalBytes = total
}
// incrwritten adds n to the running count of written bytes while holding
// the downloader's mutex.
func (d *downloader) incrwritten(n int64) {
	d.m.Lock()
	d.written = d.written + n
	d.m.Unlock()
}
// geterr returns the currently recorded error, reading it while holding the
// downloader's mutex.
func (d *downloader) geterr() error {
	d.m.Lock()
	recorded := d.err
	d.m.Unlock()
	return recorded
}
// seterr stores e as the recorded error while holding the downloader's
// mutex, replacing whatever was recorded before.
func (d *downloader) seterr(e error) {
	d.m.Lock()
	d.err = e
	d.m.Unlock()
}
// dlchunk describes one chunk of data for a worker routine to write. It
// plays the role for io.WriterAt that io.SectionReader plays for
// io.ReaderAt: a sequential writer positioned inside a larger destination
// (an "io.SectionWriter", which the standard library does not provide).
type dlchunk struct {
	w     io.WriterAt // destination shared by all chunks
	start int64       // offset in w where this chunk begins
	size  int64       // number of bytes this chunk is expected to hold
	cur   int64       // bytes written through this chunk so far
}

// Write forwards p to the underlying io.WriterAt at the chunk's current
// position (start plus the bytes already written) and advances that position
// by the number of bytes the destination accepted. Once the chunk has
// received size bytes or more, nothing is written and io.EOF is returned.
func (c *dlchunk) Write(p []byte) (int, error) {
	if c.size <= c.cur {
		return 0, io.EOF
	}

	offset := c.start + c.cur
	written, err := c.w.WriteAt(p, offset)
	c.cur += int64(written)

	return written, err
}