package sqs

import (
	"crypto/md5"
	"encoding/hex"
	"fmt"
	"strings"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/request"
)

// errChecksumMissingBody is returned by checksumsMatch when there is no
// message body to hash.
var errChecksumMissingBody = fmt.Errorf("cannot compute checksum. missing body")

// errChecksumMissingMD5 is returned by checksumsMatch when a body is present
// but no MD5 digest was supplied to compare it against.
var errChecksumMissingMD5 = fmt.Errorf("cannot verify checksum. missing response MD5")

// setupChecksumValidation attaches an MD5 verification handler to the end of
// r's Unmarshal handler list for the SendMessage, SendMessageBatch and
// ReceiveMessage operations. Other operations are left untouched, and nothing
// is registered when r.Config.DisableComputeChecksums is true.
func setupChecksumValidation(r *request.Request) {
	// The opt-out is checked before r.Operation is inspected at all.
	if aws.BoolValue(r.Config.DisableComputeChecksums) {
		return
	}

	var verifier func(*request.Request)
	switch r.Operation.Name {
	case opSendMessage:
		verifier = verifySendMessage
	case opSendMessageBatch:
		verifier = verifySendMessageBatch
	case opReceiveMessage:
		verifier = verifyReceiveMessage
	}

	if verifier != nil {
		r.Handlers.Unmarshal.PushBack(verifier)
	}
}

// verifySendMessage compares the MD5OfMessageBody returned for a SendMessage
// call with the MD5 of the MessageBody that was sent. On a mismatch, or when
// either value is missing, it sets a retryable "InvalidChecksum" error on r
// via setChecksumError. It does nothing unless both r.Params and r.Data have
// been filled.
func verifySendMessage(r *request.Request) {
	if !r.DataFilled() || !r.ParamsFilled() {
		return
	}

	in := r.Params.(*SendMessageInput)
	out := r.Data.(*SendMessageOutput)
	if err := checksumsMatch(in.MessageBody, out.MD5OfMessageBody); err != nil {
		// setChecksumError is printf-style. The error text embeds the
		// response's MD5 value, so it must be an argument, never the format
		// string, or a stray '%' in it would be treated as a verb.
		setChecksumError(r, "%s", err.Error())
	}
}

// verifySendMessageBatch checks each entry listed in the response's
// Successful slice. The MD5OfMessageBody returned for it must match the MD5 of
// the MessageBody the caller submitted under the same entry Id. Every
// mismatching entry is named in a single retryable "InvalidChecksum" error set
// on r via setChecksumError. It does nothing unless both r.Params and r.Data
// have been filled.
func verifySendMessageBatch(r *request.Request) {
	if !r.DataFilled() || !r.ParamsFilled() {
		return
	}

	out := r.Data.(*SendMessageBatchOutput)
	results := make(map[string]*SendMessageBatchResultEntry, len(out.Successful))
	for _, res := range out.Successful {
		// Skip malformed result entries instead of dereferencing a nil Id.
		if res == nil || res.Id == nil {
			continue
		}
		results[*res.Id] = res
	}

	var ids []string
	in := r.Params.(*SendMessageBatchInput)
	for _, entry := range in.Entries {
		// An input entry without an Id cannot be paired with a result.
		if entry.Id == nil {
			continue
		}
		res := results[*entry.Id]
		if res == nil {
			continue
		}
		if err := checksumsMatch(entry.MessageBody, res.MD5OfMessageBody); err != nil {
			// Prefer the SQS-assigned MessageId. Fall back to the caller's
			// batch entry Id so a missing MessageId can't cause a nil
			// dereference and the mismatch is still reported.
			id := *entry.Id
			if res.MessageId != nil {
				id = *res.MessageId
			}
			ids = append(ids, id)
		}
	}

	if len(ids) > 0 {
		setChecksumError(r, "invalid messages: %s", strings.Join(ids, ", "))
	}
}

// verifyReceiveMessage recomputes the MD5 of every received message Body and
// compares it with the MD5OfBody from the response. The MessageIds of all
// mismatching messages are combined into one retryable "InvalidChecksum" error
// on r. A mismatching message that has no MessageId cannot be named in that
// error. It is only reported as a WARN line through r.Config.Logger, when a
// logger is configured. It does nothing unless both r.Params and r.Data have
// been filled.
func verifyReceiveMessage(r *request.Request) {
	if !r.DataFilled() || !r.ParamsFilled() {
		return
	}

	var badIDs []string
	for idx, m := range r.Data.(*ReceiveMessageOutput).Messages {
		if checksumsMatch(m.Body, m.MD5OfBody) == nil {
			continue
		}
		if m.MessageId != nil {
			badIDs = append(badIDs, *m.MessageId)
			continue
		}
		if logger := r.Config.Logger; logger != nil {
			logger.Log(fmt.Sprintf(
				"WARN: SQS.ReceiveMessage failed checksum request id: %s, message %d has no message ID.",
				r.RequestID, idx,
			))
		}
	}

	if len(badIDs) > 0 {
		setChecksumError(r, "invalid messages: %s", strings.Join(badIDs, ", "))
	}
}

// checksumsMatch checks whether *expectedMD5 equals the lowercase hex encoding
// of the MD5 digest of *body. The comparison is exact and case-sensitive. It
// returns errChecksumMissingBody if body is nil (checked first),
// errChecksumMissingMD5 if expectedMD5 is nil, an error naming both digests
// if they differ, and nil if they agree.
func checksumsMatch(body, expectedMD5 *string) error {
	switch {
	case body == nil:
		return errChecksumMissingBody
	case expectedMD5 == nil:
		return errChecksumMissingMD5
	}

	digest := md5.Sum([]byte(*body))
	actual := hex.EncodeToString(digest[:])
	if actual == *expectedMD5 {
		return nil
	}
	return fmt.Errorf("expected MD5 checksum '%s', got '%s'", *expectedMD5, actual)
}

// setChecksumError overwrites r.Error with an awserr error whose code is
// "InvalidChecksum". Its message is format rendered with args using
// fmt.Sprintf semantics. It also marks the request retryable.
func setChecksumError(r *request.Request, format string, args ...interface{}) {
	message := fmt.Sprintf(format, args...)
	r.Error = awserr.New("InvalidChecksum", message, nil)
	r.Retryable = aws.Bool(true)
}