mirror of
https://github.com/Luzifer/mondash.git
synced 2025-01-10 12:51:49 +00:00
257 lines
6.1 KiB
Go
257 lines
6.1 KiB
Go
package s3manager
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/aws/aws-sdk-go/aws/awsutil"
|
|
"github.com/aws/aws-sdk-go/service/s3"
|
|
)
|
|
|
|
// The default range of bytes to get at a time when using Download().
|
|
var DefaultDownloadPartSize int64 = 1024 * 1024 * 5
|
|
|
|
// The default number of goroutines to spin up when using Download().
|
|
var DefaultDownloadConcurrency = 5
|
|
|
|
// The default set of options used when opts is nil in Download().
|
|
var DefaultDownloadOptions = &DownloadOptions{
|
|
PartSize: DefaultDownloadPartSize,
|
|
Concurrency: DefaultDownloadConcurrency,
|
|
}
|
|
|
|
// DownloadOptions keeps tracks of extra options to pass to an Download() call.
|
|
type DownloadOptions struct {
|
|
// The buffer size (in bytes) to use when buffering data into chunks and
|
|
// sending them as parts to S3. The minimum allowed part size is 5MB, and
|
|
// if this value is set to zero, the DefaultPartSize value will be used.
|
|
PartSize int64
|
|
|
|
// The number of goroutines to spin up in parallel when sending parts.
|
|
// If this is set to zero, the DefaultConcurrency value will be used.
|
|
Concurrency int
|
|
|
|
// An S3 client to use when performing downloads. Leave this as nil to use
|
|
// a default client.
|
|
S3 *s3.S3
|
|
}
|
|
|
|
// NewDownloader creates a new Downloader structure that downloads an object
|
|
// from S3 in concurrent chunks. Pass in an optional DownloadOptions struct
|
|
// to customize the downloader behavior.
|
|
func NewDownloader(opts *DownloadOptions) *Downloader {
|
|
if opts == nil {
|
|
opts = DefaultDownloadOptions
|
|
}
|
|
return &Downloader{opts: opts}
|
|
}
|
|
|
|
// The Downloader structure that calls Download(). It is safe to call Download()
|
|
// on this structure for multiple objects and across concurrent goroutines.
|
|
type Downloader struct {
|
|
opts *DownloadOptions
|
|
}
|
|
|
|
// Download downloads an object in S3 and writes the payload into w using
|
|
// concurrent GET requests.
|
|
//
|
|
// It is safe to call this method for multiple objects and across concurrent
|
|
// goroutines.
|
|
func (d *Downloader) Download(w io.WriterAt, input *s3.GetObjectInput) (n int64, err error) {
|
|
impl := downloader{w: w, in: input, opts: *d.opts}
|
|
return impl.download()
|
|
}
|
|
|
|
// downloader is the implementation structure used internally by Downloader.
|
|
type downloader struct {
|
|
opts DownloadOptions
|
|
in *s3.GetObjectInput
|
|
w io.WriterAt
|
|
|
|
wg sync.WaitGroup
|
|
m sync.Mutex
|
|
|
|
pos int64
|
|
totalBytes int64
|
|
written int64
|
|
err error
|
|
}
|
|
|
|
// init initializes the downloader with default options.
|
|
func (d *downloader) init() {
|
|
d.totalBytes = -1
|
|
|
|
if d.opts.Concurrency == 0 {
|
|
d.opts.Concurrency = DefaultDownloadConcurrency
|
|
}
|
|
|
|
if d.opts.PartSize == 0 {
|
|
d.opts.PartSize = DefaultDownloadPartSize
|
|
}
|
|
|
|
if d.opts.S3 == nil {
|
|
d.opts.S3 = s3.New(nil)
|
|
}
|
|
}
|
|
|
|
// download performs the implementation of the object download across ranged
|
|
// GETs.
|
|
func (d *downloader) download() (n int64, err error) {
|
|
d.init()
|
|
|
|
// Spin up workers
|
|
ch := make(chan dlchunk, d.opts.Concurrency)
|
|
for i := 0; i < d.opts.Concurrency; i++ {
|
|
d.wg.Add(1)
|
|
go d.downloadPart(ch)
|
|
}
|
|
|
|
// Assign work
|
|
for d.geterr() == nil {
|
|
if d.pos != 0 {
|
|
// This is not the first chunk, let's wait until we know the total
|
|
// size of the payload so we can see if we have read the entire
|
|
// object.
|
|
total := d.getTotalBytes()
|
|
|
|
if total < 0 {
|
|
// Total has not yet been set, so sleep and loop around while
|
|
// waiting for our first worker to resolve this value.
|
|
time.Sleep(10 * time.Millisecond)
|
|
continue
|
|
} else if d.pos >= total {
|
|
break // We're finished queueing chunks
|
|
}
|
|
}
|
|
|
|
// Queue the next range of bytes to read.
|
|
ch <- dlchunk{w: d.w, start: d.pos, size: d.opts.PartSize}
|
|
d.pos += d.opts.PartSize
|
|
}
|
|
|
|
// Wait for completion
|
|
close(ch)
|
|
d.wg.Wait()
|
|
|
|
// Return error
|
|
return d.written, d.err
|
|
}
|
|
|
|
// downloadPart is an individual goroutine worker reading from the ch channel
|
|
// and performing a GetObject request on the data with a given byte range.
|
|
//
|
|
// If this is the first worker, this operation also resolves the total number
|
|
// of bytes to be read so that the worker manager knows when it is finished.
|
|
func (d *downloader) downloadPart(ch chan dlchunk) {
|
|
defer d.wg.Done()
|
|
|
|
for {
|
|
chunk, ok := <-ch
|
|
|
|
if !ok {
|
|
break
|
|
}
|
|
|
|
if d.geterr() == nil {
|
|
// Get the next byte range of data
|
|
in := &s3.GetObjectInput{}
|
|
awsutil.Copy(in, d.in)
|
|
rng := fmt.Sprintf("bytes=%d-%d",
|
|
chunk.start, chunk.start+chunk.size-1)
|
|
in.Range = &rng
|
|
|
|
resp, err := d.opts.S3.GetObject(in)
|
|
if err != nil {
|
|
d.seterr(err)
|
|
} else {
|
|
d.setTotalBytes(resp) // Set total if not yet set.
|
|
|
|
n, err := io.Copy(&chunk, resp.Body)
|
|
resp.Body.Close()
|
|
|
|
if err != nil {
|
|
d.seterr(err)
|
|
}
|
|
d.incrwritten(n)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// getTotalBytes is a thread-safe getter for retrieving the total byte status.
|
|
func (d *downloader) getTotalBytes() int64 {
|
|
d.m.Lock()
|
|
defer d.m.Unlock()
|
|
|
|
return d.totalBytes
|
|
}
|
|
|
|
// getTotalBytes is a thread-safe setter for setting the total byte status.
|
|
func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
|
|
d.m.Lock()
|
|
defer d.m.Unlock()
|
|
|
|
if d.totalBytes >= 0 {
|
|
return
|
|
}
|
|
|
|
parts := strings.Split(*resp.ContentRange, "/")
|
|
total, err := strconv.ParseInt(parts[len(parts)-1], 10, 64)
|
|
if err != nil {
|
|
d.err = err
|
|
return
|
|
}
|
|
|
|
d.totalBytes = total
|
|
}
|
|
|
|
func (d *downloader) incrwritten(n int64) {
|
|
d.m.Lock()
|
|
defer d.m.Unlock()
|
|
|
|
d.written += n
|
|
}
|
|
|
|
// geterr is a thread-safe getter for the error object
|
|
func (d *downloader) geterr() error {
|
|
d.m.Lock()
|
|
defer d.m.Unlock()
|
|
|
|
return d.err
|
|
}
|
|
|
|
// seterr is a thread-safe setter for the error object
|
|
func (d *downloader) seterr(e error) {
|
|
d.m.Lock()
|
|
defer d.m.Unlock()
|
|
|
|
d.err = e
|
|
}
|
|
|
|
// dlchunk represents a single chunk of data to write by the worker routine.
// It adapts an io.WriterAt into an io.Writer confined to the byte window
// [start, start+size). In effect it is an io.SectionWriter, the writing
// counterpart of io.SectionReader (which the standard library does not have).
type dlchunk struct {
	w     io.WriterAt // destination shared by all chunks
	start int64       // absolute offset of the window's first byte in w
	size  int64       // length of the window in bytes
	cur   int64       // bytes already written into the window
}

// Write writes p at the chunk's current position, never past the end of the
// chunk's window. If p does not fit, only the bytes that fit are written and
// io.EOF is returned with the short count. Once the window is full, every
// further call returns 0, io.EOF. Errors from the underlying WriterAt are
// returned unchanged.
func (c *dlchunk) Write(p []byte) (n int, err error) {
	remaining := c.size - c.cur
	if remaining <= 0 {
		return 0, io.EOF
	}

	// Clamp to the window so an oversized body can never overwrite the
	// neighbouring chunk, which another worker may be writing concurrently.
	truncated := false
	if int64(len(p)) > remaining {
		p = p[:remaining]
		truncated = true
	}

	n, err = c.w.WriteAt(p, c.start+c.cur)
	c.cur += int64(n)

	if err == nil && truncated {
		// io.Writer requires a non-nil error whenever n < len(p).
		err = io.EOF
	}
	return n, err
}
|