diff --git a/providers/s3/provider.go b/providers/s3/provider.go index 7918f16..2b8502a 100644 --- a/providers/s3/provider.go +++ b/providers/s3/provider.go @@ -2,7 +2,9 @@ package s3 import ( "bytes" + "context" "crypto/md5" + "fmt" "hash" "io" "net/url" @@ -12,14 +14,16 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/pkg/errors" "github.com/Luzifer/cloudbox/providers" ) type Provider struct { - bucket string - s3 *s3.S3 + bucket string + bucketRegion string + s3 *s3.S3 } func New(uri string) (providers.CloudProvider, error) { @@ -39,11 +43,18 @@ func New(uri string) (providers.CloudProvider, error) { cfg = cfg.WithCredentials(credentials.NewStaticCredentials(user, pass, "")) } - svc := s3.New(session.Must(session.NewSession(cfg))) + sess := session.Must(session.NewSession(cfg)) + svc := s3.New(sess) + + region, err := s3manager.GetBucketRegion(context.Background(), sess, u.Host, "us-east-1") + if err != nil { + return nil, errors.Wrap(err, "Unable to find bucket region") + } return &Provider{ - bucket: u.Host, - s3: svc, + bucket: u.Host, + bucketRegion: region, + s3: svc, }, nil } @@ -119,6 +130,7 @@ func (p *Provider) PutFile(f providers.File) (providers.File, error) { } if _, err = p.s3.PutObject(&s3.PutObjectInput{ + ACL: aws.String(p.getFileACL(f.Info().RelativeName)), Body: bytes.NewReader(buf.Bytes()), Bucket: aws.String(p.bucket), Key: aws.String(f.Info().RelativeName), @@ -130,5 +142,33 @@ func (p *Provider) PutFile(f providers.File) (providers.File, error) { } func (p *Provider) Share(relativeName string) (string, error) { - return "", errors.New("Not implemented") + _, err := p.s3.PutObjectAcl(&s3.PutObjectAclInput{ + ACL: aws.String(s3.ObjectCannedACLPublicRead), + Bucket: aws.String(p.bucket), + Key: aws.String(relativeName), + }) + if err != nil { + return "", errors.Wrap(err, "Unable to publish file") + } + + return fmt.Sprintf("https://s3-%s.amazonaws.com/%s/%s", p.bucketRegion, p.bucket, relativeName), nil +} + +func (p *Provider) getFileACL(relativeName string) string { + objACL, err := p.s3.GetObjectAcl(&s3.GetObjectAclInput{ + Bucket: aws.String(p.bucket), + Key: aws.String(relativeName), + }) + + if err != nil { + return s3.ObjectCannedACLPrivate + } + + for _, g := range objACL.Grants { + if *g.Grantee.URI == "http://acs.amazonaws.com/groups/global/AllUsers" && *g.Permission == "READ" { + return s3.ObjectCannedACLPublicRead + } + } + + return s3.ObjectCannedACLPrivate }