mirror of
https://github.com/Luzifer/twitch-bot.git
synced 2024-11-09 00:30:02 +00:00
142 lines
3.2 KiB
Go
142 lines
3.2 KiB
Go
package punish
|
|
|
|
import (
|
|
"database/sql"
|
|
"embed"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
//go:embed schema/**
|
|
var schema embed.FS
|
|
|
|
func calculateCurrentPunishments() error {
|
|
rows, err := db.DB().Query(
|
|
`SELECT key, last_level, executed, cooldown
|
|
FROM punish_levels;`,
|
|
)
|
|
if err != nil {
|
|
return errors.Wrap(err, "querying punish_levels")
|
|
}
|
|
|
|
for rows.Next() {
|
|
if err = rows.Err(); err != nil {
|
|
return errors.Wrap(err, "advancing rows")
|
|
}
|
|
|
|
var (
|
|
key string
|
|
lastLevel, executed, cooldown int64
|
|
|
|
actUpdate bool
|
|
)
|
|
if err = rows.Scan(&key, &lastLevel, &executed, &cooldown); err != nil {
|
|
return errors.Wrap(err, "advancing rows")
|
|
}
|
|
|
|
lvl := &levelConfig{
|
|
LastLevel: int(lastLevel),
|
|
Cooldown: time.Duration(cooldown),
|
|
Executed: time.Unix(executed, 0),
|
|
}
|
|
|
|
for {
|
|
cooldownTime := lvl.Executed.Add(lvl.Cooldown)
|
|
if cooldownTime.After(time.Now()) {
|
|
break
|
|
}
|
|
|
|
lvl.Executed = cooldownTime
|
|
lvl.LastLevel--
|
|
actUpdate = true
|
|
}
|
|
|
|
// Level 0 is the first punishment level, so only remove if it drops below 0
|
|
if lvl.LastLevel < 0 {
|
|
if err = deletePunishmentForKey(key); err != nil {
|
|
return errors.Wrap(err, "cleaning up expired punishment")
|
|
}
|
|
continue
|
|
}
|
|
|
|
if actUpdate {
|
|
if err = setPunishmentForKey(key, lvl); err != nil {
|
|
return errors.Wrap(err, "updating punishment")
|
|
}
|
|
}
|
|
}
|
|
|
|
return errors.Wrap(rows.Err(), "finishing rows processing")
|
|
}
|
|
|
|
func deletePunishment(channel, user, uuid string) error {
|
|
return deletePunishmentForKey(getDBKey(channel, user, uuid))
|
|
}
|
|
|
|
func deletePunishmentForKey(key string) error {
|
|
_, err := db.DB().Exec(
|
|
`DELETE FROM punish_levels
|
|
WHERE key = $1;`,
|
|
key,
|
|
)
|
|
|
|
return errors.Wrap(err, "deleting punishment info")
|
|
}
|
|
|
|
func getPunishment(channel, user, uuid string) (*levelConfig, error) {
|
|
if err := calculateCurrentPunishments(); err != nil {
|
|
return nil, errors.Wrap(err, "updating punishment states")
|
|
}
|
|
|
|
row := db.DB().QueryRow(
|
|
`SELECT last_level, executed, cooldown
|
|
FROM punish_levels
|
|
WHERE key = $1;`,
|
|
getDBKey(channel, user, uuid),
|
|
)
|
|
|
|
lc := &levelConfig{LastLevel: -1}
|
|
|
|
var lastLevel, executed, cooldown int64
|
|
err := row.Scan(&lastLevel, &executed, &cooldown)
|
|
switch {
|
|
case err == nil:
|
|
lc.LastLevel = int(lastLevel)
|
|
lc.Cooldown = time.Duration(cooldown)
|
|
lc.Executed = time.Unix(executed, 0)
|
|
|
|
return lc, nil
|
|
|
|
case errors.Is(err, sql.ErrNoRows):
|
|
return lc, nil
|
|
|
|
default:
|
|
return nil, errors.Wrap(err, "getting punishment from database")
|
|
}
|
|
}
|
|
|
|
func setPunishment(channel, user, uuid string, lc *levelConfig) error {
|
|
return setPunishmentForKey(getDBKey(channel, user, uuid), lc)
|
|
}
|
|
|
|
func setPunishmentForKey(key string, lc *levelConfig) error {
|
|
_, err := db.DB().Exec(
|
|
`INSERT INTO punish_levels
|
|
(key, last_level, executed, cooldown)
|
|
VALUES ($1, $2, $3, $4)
|
|
ON CONFLICT DO UPDATE
|
|
SET last_level = excluded.last_level,
|
|
executed = excluded.executed,
|
|
cooldown = excluded.cooldown;`,
|
|
key,
|
|
lc.LastLevel, lc.Executed.UTC().Unix(), int64(lc.Cooldown),
|
|
)
|
|
|
|
return errors.Wrap(err, "updating punishment info")
|
|
}
|
|
|
|
func getDBKey(channel, user, uuid string) string {
|
|
return strings.Join([]string{channel, user, uuid}, "::")
|
|
}
|