1
0
Fork 0
mirror of https://github.com/Luzifer/share.git synced 2024-10-18 05:14:23 +00:00
share/upload.go

161 lines
3.8 KiB
Go
Raw Normal View History

package main
import (
"bytes"
"compress/gzip"
"crypto/sha256"
"fmt"
"html/template"
"io"
"mime"
"path"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/cheggaaa/pb/v3"
"github.com/gofrs/uuid"
"github.com/gosimple/slug"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)
// barUpdateInterval is how often the progress bar is refreshed from the
// upload's byte counter while an upload with progress display is running.
const barUpdateInterval time.Duration = 100 * time.Millisecond
// executeUpload uploads the content of inFileHandle to cfg.Bucket and returns
// the resulting URL (cfg.BaseURL immediately followed by the object key).
//
//   - inFileName: used verbatim as the object key unless
//     useCalculatedFilename is set; always used as the progress-bar prefix.
//   - inFileHandle: the content to upload.
//   - useCalculatedFilename: derive the key via calculateUploadFilename
//     (cfg.FileTemplate) instead of using inFileName.
//   - overrideMimeType: if non-empty, sent as Content-Type instead of the
//     type guessed from the key's extension (fallback
//     application/octet-stream).
//   - forceGzip: gzip the whole content in memory and send it with
//     Content-Encoding: gzip.
//
// When cfg.Endpoint is set it is used as a custom S3 endpoint with
// path-style addressing. When cfg.Progress is set a progress bar is shown
// while the object is being uploaded.
//
//revive:disable-next-line:flag-parameter // Fine in this case
func executeUpload(inFileName string, inFileHandle io.ReadSeeker, useCalculatedFilename bool, overrideMimeType string, forceGzip bool) (string, error) {
	var (
		upFile = inFileName
		err    error
	)

	if useCalculatedFilename {
		if upFile, err = calculateUploadFilename(inFileName, inFileHandle); err != nil {
			return "", errors.Wrap(err, "calculating upload filename")
		}

		// Calculating the filename hashes the content, reading the handle to
		// EOF. Rewind so the content is read again from the start for
		// compression / upload; otherwise forceGzip would upload an empty
		// gzip stream.
		if _, err = inFileHandle.Seek(0, io.SeekStart); err != nil {
			return "", errors.Wrap(err, "rewinding file after hashing")
		}
	}

	mimeType := mime.TypeByExtension(path.Ext(upFile))
	if mimeType == "" {
		mimeType = "application/octet-stream"
	}
	if overrideMimeType != "" {
		mimeType = overrideMimeType
	}

	log.Debugf("Uploading file to %q with type %q", upFile, mimeType)

	var contentEncoding *string
	if forceGzip {
		// The whole compressed payload is buffered in memory so it can be
		// handed on as a seekable reader.
		buf := new(bytes.Buffer)
		gw := gzip.NewWriter(buf)
		if _, err := io.Copy(gw, inFileHandle); err != nil {
			return "", errors.Wrap(err, "compressing file")
		}
		if err := gw.Close(); err != nil {
			return "", errors.Wrap(err, "closing gzip writer")
		}

		inFileHandle = bytes.NewReader(buf.Bytes())
		contentEncoding = aws.String("gzip")
	}

	var awsCfgs []*aws.Config
	if cfg.Endpoint != "" {
		awsCfgs = append(awsCfgs, &aws.Config{Endpoint: &cfg.Endpoint, S3ForcePathStyle: aws.Bool(true)})
	}

	// Report session setup failures as errors instead of panicking.
	sess, err := session.NewSession(awsCfgs...)
	if err != nil {
		return "", errors.Wrap(err, "creating AWS session")
	}
	svc := s3.New(sess)

	ps, err := newProgressSeeker(inFileHandle)
	if err != nil {
		return "", errors.Wrap(err, "preparing upload reader")
	}

	if cfg.Progress {
		bar := pb.New64(ps.Size)
		bar.Set(pb.Bytes, true)
		bar.Set("prefix", inFileName)
		bar.Start()

		// stop tells the refresher goroutine to exit; done is closed once
		// it has exited. That guarantees the bar is never touched after
		// Finish and avoids the unsynchronized shared bool.
		stop := make(chan struct{})
		done := make(chan struct{})

		go func() {
			defer close(done)

			ticker := time.NewTicker(barUpdateInterval)
			defer ticker.Stop()

			for {
				select {
				case <-stop:
					return
				case <-ticker.C:
					// NOTE(review): ps.Progress is presumably advanced by the
					// uploader's reads; confirm progressSeeker updates it in
					// a race-free way.
					bar.SetCurrent(ps.Progress)
				}
			}
		}()

		defer func() {
			close(stop)
			<-done
			bar.SetCurrent(ps.Progress)
			bar.Finish()
		}()
	}

	if _, err = svc.PutObject(&s3.PutObjectInput{
		Body:            ps,
		Bucket:          aws.String(cfg.Bucket),
		ContentEncoding: contentEncoding,
		ContentType:     aws.String(mimeType),
		Key:             aws.String(upFile),
	}); err != nil {
		return "", fmt.Errorf("putting object to S3: %w", err)
	}

	return fmt.Sprintf("%s%s", cfg.BaseURL, upFile), nil
}
// calculateUploadFilename renders cfg.FileTemplate to produce the object key
// for inFile. The template receives:
//   - Ext:          extension of inFile including the leading dot ("" if none)
//   - FileName:     last slash-separated element of inFile
//   - Hash:         hex-encoded SHA-256 of the content of inFileHandle
//   - SafeFileName: slug of FileName without its extension, followed by Ext
//   - UUID:         a newly generated version-4 UUID
//
// inFileHandle is rewound and read to the end while hashing.
func calculateUploadFilename(inFile string, inFileHandle io.ReadSeeker) (string, error) {
	fileHash, err := hashFile(inFileHandle)
	if err != nil {
		return "", err
	}

	// uuid.NewV4 reports failures as an error; propagate it instead of
	// panicking through uuid.Must.
	id, err := uuid.NewV4()
	if err != nil {
		return "", fmt.Errorf("generating UUID: %w", err)
	}

	ext := path.Ext(inFile)
	baseName := path.Base(inFile)
	safeFileName := slug.Make(strings.TrimSuffix(baseName, ext)) + ext

	return executeTemplate(cfg.FileTemplate, map[string]interface{}{
		"Ext":          ext,
		"FileName":     baseName,
		"Hash":         fileHash,
		"SafeFileName": safeFileName,
		"UUID":         id.String(),
	})
}
// hashFile rewinds inFileHandle to offset 0, streams all of its content
// through SHA-256 and returns the digest as a lowercase hex string.
// The reader is left wherever the copy stopped (at EOF on success).
func hashFile(inFileHandle io.ReadSeeker) (hexHash string, err error) {
	if _, seekErr := inFileHandle.Seek(0, io.SeekStart); seekErr != nil {
		return "", fmt.Errorf("resetting reader: %w", seekErr)
	}

	digest := sha256.New()
	if _, copyErr := io.Copy(digest, inFileHandle); copyErr != nil {
		return "", fmt.Errorf("reading data into hash: %w", copyErr)
	}

	sum := digest.Sum(nil)
	hexHash = fmt.Sprintf("%x", sum)
	return hexHash, nil
}
// executeTemplate parses tplStr as a template named "filename", executes it
// with vars as data and returns the rendered string, which callers use as an
// object key / filename.
//
// The template engine in use (html/template) contextually escapes values,
// which would turn characters such as '+', '&', '<' or '\'' into HTML
// entities (e.g. "c++" -> "c&#43;&#43;"). Such entities are never wanted in a
// filename. Plain string values are therefore passed as template.HTML, which
// html/template emits verbatim in text context. Values first converted back
// to a plain string inside the template (e.g. via printf) are still escaped.
// The caller's map is not modified.
func executeTemplate(tplStr string, vars map[string]interface{}) (string, error) {
	tpl, err := template.New("filename").Parse(tplStr)
	if err != nil {
		return "", fmt.Errorf("parsing filename template: %w", err)
	}

	data := make(map[string]interface{}, len(vars))
	for key, val := range vars {
		if s, ok := val.(string); ok {
			data[key] = template.HTML(s) //nolint:gosec // output is a filename, not HTML served to a browser
			continue
		}
		data[key] = val
	}

	buf := new(bytes.Buffer)
	if err = tpl.Execute(buf, data); err != nil {
		return "", fmt.Errorf("executing filename template: %w", err)
	}

	return buf.String(), nil
}