package timer

import (
	"crypto/sha256"
	"fmt"
	"strings"
	"time"

	"github.com/pkg/errors"
	"github.com/robfig/cron/v3"
	"github.com/sirupsen/logrus"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"

	"github.com/Luzifer/go_helpers/v2/backoff"
	"github.com/Luzifer/twitch-bot/v3/internal/helpers"
	"github.com/Luzifer/twitch-bot/v3/pkg/database"
	"github.com/Luzifer/twitch-bot/v3/plugins"
)

// Service stores cooldown, permit and generic timers in the database and
// implements plugins.TimerStore. The permitTimeout field holds the validity
// window AddPermit applies; New leaves it zero until UpdatePermitTimeout is
// called.
type Service struct {
	db            database.Connector
	permitTimeout time.Duration
}

// timer is the database row for a single timer: its key (ID) and the
// instant after which it no longer counts as active. SetTimer stores
// ExpiresAt in UTC.
type timer struct {
	ID        string `gorm:"primaryKey"`
	ExpiresAt time.Time
}

// Compile-time check that *Service satisfies plugins.TimerStore.
var _ plugins.TimerStore = (*Service)(nil)

// New creates a timer Service backed by db and applies the migration for the
// timer table. When cronService is non-nil, a job that deletes expired timers
// every five minutes is registered on it.
//
// On error the returned *Service is nil and no cleanup job is left behind
// for a failed migration.
func New(db database.Connector, cronService *cron.Cron) (*Service, error) {
	s := &Service{
		db: db,
	}

	// Migrate before scheduling cleanup: the job must never be registered
	// for a service whose table could not be created.
	if err := s.db.DB().AutoMigrate(&timer{}); err != nil {
		return nil, errors.Wrap(err, "applying migrations")
	}

	if cronService != nil {
		if _, err := cronService.AddFunc("@every 5m", s.cleanupTimers); err != nil {
			return nil, errors.Wrap(err, "registering timer cleanup cron")
		}
	}

	return s, nil
}

// CopyDatabase copies all stored timer rows from src to target using
// database.CopyObjects.
func (*Service) CopyDatabase(src, target *gorm.DB) error {
	model := &timer{}
	return database.CopyObjects(src, target, model)
}

// UpdatePermitTimeout sets how long permits created by AddPermit stay valid.
// Permits that were already stored keep the expiry computed when they were
// added.
func (s *Service) UpdatePermitTimeout(timeout time.Duration) {
	s.permitTimeout = timeout
}

// Cooldown timer

// AddCooldown starts (or extends) a cooldown for the given timer type,
// limiter and rule that stays active until expiry.
func (s Service) AddCooldown(tt plugins.TimerType, limiter, ruleID string, expiry time.Time) error {
	key := s.getCooldownTimerKey(tt, limiter, ruleID)
	return s.SetTimer(key, expiry)
}

// InCooldown reports whether an unexpired cooldown exists for the given
// timer type, limiter and rule.
func (s Service) InCooldown(tt plugins.TimerType, limiter, ruleID string) (bool, error) {
	key := s.getCooldownTimerKey(tt, limiter, ruleID)
	return s.HasTimer(key)
}

// getCooldownTimerKey derives the storage key for a cooldown: "sha256:"
// followed by the hex SHA-256 digest of "<type>:<limiter>:<ruleID>".
// The format must stay stable because existing rows are looked up by it.
func (Service) getCooldownTimerKey(tt plugins.TimerType, limiter, ruleID string) string {
	payload := fmt.Sprintf("%d:%s:%s", tt, limiter, ruleID)
	digest := sha256.Sum256([]byte(payload))
	return fmt.Sprintf("sha256:%x", digest[:])
}

// Permit timer

// AddPermit grants username a permit in channel that expires after the
// currently configured permit timeout. If UpdatePermitTimeout was never
// called, that timeout is zero.
func (s Service) AddPermit(channel, username string) error {
	key := s.getPermitTimerKey(channel, username)
	expiry := time.Now().Add(s.permitTimeout)
	return s.SetTimer(key, expiry)
}

// HasPermit reports whether username currently holds an unexpired permit
// in channel.
func (s Service) HasPermit(channel, username string) (bool, error) {
	key := s.getPermitTimerKey(channel, username)
	return s.HasTimer(key)
}

// getPermitTimerKey derives the storage key for a permit. The username is
// normalized (all leading "@" characters removed, then lower-cased) so that
// "@Foo" and "foo" share one permit. The channel is used exactly as given.
// The format must stay stable because existing rows are looked up by it.
func (Service) getPermitTimerKey(channel, username string) string {
	user := strings.ToLower(strings.TrimLeft(username, "@"))
	payload := fmt.Sprintf("%d:%s:%s", plugins.TimerTypePermit, channel, user)
	digest := sha256.Sum256([]byte(payload))
	return fmt.Sprintf("sha256:%x", digest[:])
}

// Generic timer

// HasTimer reports whether a timer with the given id exists and has not yet
// expired (its expiry is at or after the current UTC time). A missing row
// yields (false, nil); any other database error is returned wrapped.
func (s Service) HasTimer(id string) (bool, error) {
	var found timer

	lookup := func() error {
		err := s.db.DB().First(&found, "id = ? AND expires_at >= ?", id, time.Now().UTC()).Error
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return err
		}
		// "Not found" is a definitive answer rather than a transient
		// failure, so it is marked as non-retryable for helpers.Retry.
		return backoff.NewErrCannotRetry(err)
	}

	err := helpers.Retry(lookup)
	if err == nil {
		return true, nil
	}
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return false, nil
	}
	return false, errors.Wrap(err, "getting timer information")
}

// SetTimer creates the timer id, or updates the expiry of an existing timer
// with that id, so that it stays active until expiry. The expiry is
// normalized to UTC before storing. The upsert runs via
// helpers.RetryTransaction.
func (s Service) SetTimer(id string, expiry time.Time) error {
	err := helpers.RetryTransaction(s.db.DB(), func(tx *gorm.DB) error {
		// Upsert keyed on id: an existing row only gets expires_at replaced.
		// A pointer is passed to Create, matching GORM's documented usage.
		return tx.Clauses(clause.OnConflict{
			Columns:   []clause.Column{{Name: "id"}},
			DoUpdates: clause.AssignmentColumns([]string{"expires_at"}),
		}).Create(&timer{
			ID:        id,
			ExpiresAt: expiry.UTC(),
		}).Error
	})

	// errors.Wrap returns nil when err is nil.
	return errors.Wrap(err, "storing timer in database")
}

// cleanupTimers deletes every timer whose expiry lies before the current
// UTC time. It is meant to run as a background job, so failures are logged
// rather than returned.
func (s Service) cleanupTimers() {
	deleteExpired := func(tx *gorm.DB) error {
		return tx.Delete(&timer{}, "expires_at < ?", time.Now().UTC()).Error
	}

	err := helpers.RetryTransaction(s.db.DB(), deleteExpired)
	if err == nil {
		return
	}
	logrus.WithError(err).Error("cleaning up expired timers")
}