accounting/pkg/database/database.go

582 lines
15 KiB
Go
Raw Permalink Normal View History

2024-01-15 22:49:29 +00:00
// Package database has a database client to access the transactions
// and account database together with helpers to interact with those
// tables
package database
import (
"errors"
"fmt"
2024-01-17 22:02:41 +00:00
"math"
2024-01-15 22:49:29 +00:00
"time"
"github.com/Luzifer/go_helpers/v2/backoff"
"github.com/glebarez/sqlite"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
const dbMaxRetries = 5
type (
// Client is the database client
Client struct {
db *gorm.DB
}
)
// New creates a new database client for the given DSN
func New(dbtype, dsn string) (*Client, error) {
var conn gorm.Dialector
switch dbtype {
case "cockroach", "crdb", "postgres", "postgresql":
conn = postgres.Open(dsn)
case "sqlite", "sqlite3":
conn = sqlite.Open(dsn)
default:
return nil, fmt.Errorf("unknown db-type %s", dbtype)
}
db, err := gorm.Open(conn, &gorm.Config{
Logger: logger.New(loggerWriter{logrus.StandardLogger().WriterLevel(logrus.TraceLevel)}, logger.Config{
SlowThreshold: time.Second,
Colorful: false,
IgnoreRecordNotFoundError: false,
ParameterizedQueries: false,
LogLevel: logger.Info,
}),
})
if err != nil {
return nil, fmt.Errorf("opening database: %w", err)
}
if err = db.AutoMigrate(
&Account{},
&Transaction{},
); err != nil {
return nil, fmt.Errorf("migrating database schema: %w", err)
}
2024-01-16 15:13:04 +00:00
for i := range migrateCreateAccounts {
a := migrateCreateAccounts[i]
if err = db.Save(&a).Error; err != nil {
return nil, fmt.Errorf("ensuring default account %q: %w", a.Name, err)
}
2024-01-15 22:49:29 +00:00
}
return &Client{
db: db,
}, nil
}
// CreateAccount creates and returns a new account of the given type
func (c *Client) CreateAccount(name string, accType AccountType) (a Account, err error) {
a = Account{
Name: name,
Type: accType,
}
if !accType.IsValid() {
return a, fmt.Errorf("invalid account type %s", accType)
}
if err = c.retryTx(func(db *gorm.DB) error {
return db.Save(&a).Error
}); err != nil {
return a, fmt.Errorf("creating account: %w", err)
}
return a, nil
}
// CreateTransaction takes a prepared transaction and stores it
func (c *Client) CreateTransaction(tx Transaction) (ntx Transaction, err error) {
if err = tx.Validate(c); err != nil {
return tx, fmt.Errorf("validating transaction: %w", err)
}
if err = c.retryTx(func(db *gorm.DB) error {
return db.Save(&tx).Error
}); err != nil {
return tx, fmt.Errorf("creating transaction: %w", err)
}
return tx, nil
}
// DeleteTransaction deletes a transaction
func (c *Client) DeleteTransaction(id uuid.UUID) (err error) {
if err = c.retryTx(func(db *gorm.DB) error {
tx, err := c.GetTransactionByID(id)
if err != nil {
return err
}
if tx.PairKey.Valid {
// We got a paired transaction which would be out-of-sync if we
// only delete one part of it so instead of doing a delete on the
// ID of the transaction, we do a delete on the pair-key
return db.Delete(&Transaction{}, "pair_key = ?", tx.PairKey.UUID).Error
}
2024-01-15 22:49:29 +00:00
return db.Delete(&Transaction{}, "id = ?", id).Error
}); err != nil {
return fmt.Errorf("deleting transaction: %w", err)
}
return nil
}
// GetAccount retrieves an Account using its ID
func (c *Client) GetAccount(id uuid.UUID) (a Account, err error) {
if err = c.retryRead(func(db *gorm.DB) error {
return db.First(&a, "id = ?", id).Error
}); err != nil {
return a, fmt.Errorf("fetching account: %w", err)
}
return a, nil
}
// GetTransactionByID returns a single transaction by its ID
func (c *Client) GetTransactionByID(id uuid.UUID) (tx Transaction, err error) {
if err = c.retryRead(func(db *gorm.DB) error {
return db.First(&tx, "id = ?", id).Error
}); err != nil {
return tx, fmt.Errorf("getting transaction: %w", err)
}
return tx, nil
}
// ListAccountBalances returns a list of accounts with their
// corresponding balance
2024-01-16 15:13:04 +00:00
func (c *Client) ListAccountBalances(showHidden bool) (a []AccountBalance, err error) {
accs, err := c.ListAccounts(showHidden)
2024-01-15 22:49:29 +00:00
if err != nil {
return nil, fmt.Errorf("listing accounts: %w", err)
}
for _, acc := range accs {
if err = c.retryRead(func(db *gorm.DB) error {
q := db.
Model(&Transaction{})
if acc.Type == AccountTypeCategory {
q = q.Where("category = ?", acc.ID)
} else {
q = q.Where("account = ?", acc.ID)
}
ab := AccountBalance{
Account: acc,
Balance: 0,
}
var v *float64
if err = q.
Select("sum(amount)").
Scan(&v).
Error; err != nil {
return fmt.Errorf("getting sum: %w", err)
}
if v != nil {
2024-01-17 22:02:41 +00:00
// Fix database doing e-15 stuff by rounding to full cents
ab.Balance = math.Round(*v*100) / 100 //nolint:mnd
2024-01-15 22:49:29 +00:00
}
a = append(a, ab)
return nil
}); err != nil {
return nil, fmt.Errorf("getting account balance for %s: %w", acc.ID, err)
}
}
return a, nil
}
// ListAccounts returns a list of all accounts
2024-01-16 15:13:04 +00:00
//
//revive:disable-next-line:flag-parameter
func (c *Client) ListAccounts(showHidden bool) (a []Account, err error) {
2024-01-15 22:49:29 +00:00
if err = c.retryRead(func(db *gorm.DB) error {
2024-01-16 15:13:04 +00:00
q := db.Model(&Account{})
if !showHidden {
q = q.Where("hidden = ?", false)
}
return q.Find(&a).Error
2024-01-15 22:49:29 +00:00
}); err != nil {
return a, fmt.Errorf("listing accounts: %w", err)
}
return a, nil
}
// ListAccountsByType returns a list of all accounts of the given type
2024-01-16 15:13:04 +00:00
//
//revive:disable-next-line:flag-parameter
func (c *Client) ListAccountsByType(at AccountType, showHidden bool) (a []Account, err error) {
2024-01-15 22:49:29 +00:00
if err = c.retryRead(func(db *gorm.DB) error {
2024-01-16 15:13:04 +00:00
q := db.Where("type = ?", at)
if !showHidden {
q = q.Where("hidden = ?", false)
}
return q.Find(&a).Error
2024-01-15 22:49:29 +00:00
}); err != nil {
return a, fmt.Errorf("listing accounts: %w", err)
}
return a, nil
}
2024-01-17 22:02:41 +00:00
// ListTransactions retrieves all transactions
func (c *Client) ListTransactions(since, until time.Time) (txs []Transaction, err error) {
if err = c.retryRead(func(db *gorm.DB) error {
return db.
Where("time >= ? and time <= ?", since, until).
Find(&txs).
Error
}); err != nil {
return txs, fmt.Errorf("listing transactions: %w", err)
}
return txs, nil
}
2024-01-15 22:49:29 +00:00
// ListTransactionsByAccount retrieves all transactions for an account
// or category
2024-01-16 16:30:17 +00:00
func (c *Client) ListTransactionsByAccount(acc uuid.UUID, since, until time.Time) (txs []Transaction, err error) {
2024-01-15 22:49:29 +00:00
if err = c.retryRead(func(db *gorm.DB) error {
return db.
2024-01-16 16:30:17 +00:00
Where("time >= ? and time <= ?", since, until).
2024-01-15 22:49:29 +00:00
Find(&txs, "account = ? OR category = ?", acc, acc).
Error
}); err != nil {
return txs, fmt.Errorf("listing transactions: %w", err)
}
return txs, nil
}
// MarkAccountReconciled marks all cleared transactions as reconciled.
// The account balance is NOT checked in this method.
func (c *Client) MarkAccountReconciled(acc uuid.UUID) (err error) {
if err = c.retryTx(func(db *gorm.DB) error {
return db.
Model(&Transaction{}).
Where("account = ?", acc).
Where("cleared = ?", true).
Update("reconciled", true).
Error
}); err != nil {
return fmt.Errorf("updating transactions: %w", err)
}
return nil
}
2024-01-15 22:49:29 +00:00
// TransferMoney creates new Transactions for the given account
// transfer. The account type of the from and to account must match
// for this to work.
func (c *Client) TransferMoney(from, to uuid.UUID, amount float64, description string) (err error) {
2024-01-15 22:49:29 +00:00
var fromAcc, toAcc Account
if fromAcc, err = c.GetAccount(from); err != nil {
return fmt.Errorf("getting source account: %w", err)
}
if toAcc, err = c.GetAccount(to); err != nil {
return fmt.Errorf("getting target account: %w", err)
}
if fromAcc.Type != toAcc.Type {
return fmt.Errorf("account type mismatch: %s != %s", fromAcc.Type, toAcc.Type)
}
pairKey := uuid.Must(uuid.NewRandom())
2024-01-15 22:49:29 +00:00
var txs []*Transaction
switch fromAcc.Type {
case AccountTypeBudget, AccountTypeTracking:
// Create TX with null-category
txs = []*Transaction{
{
Time: time.Now().UTC(),
Payee: fmt.Sprintf("Transfer: %s → %s", fromAcc.Name, toAcc.Name),
Description: description,
2024-01-15 22:49:29 +00:00
Amount: -amount,
Account: uuid.NullUUID{UUID: from, Valid: true},
Category: uuid.NullUUID{},
Cleared: false,
PairKey: uuid.NullUUID{UUID: pairKey, Valid: true},
2024-01-15 22:49:29 +00:00
},
{
Time: time.Now().UTC(),
Payee: fmt.Sprintf("Transfer: %s → %s", fromAcc.Name, toAcc.Name),
Description: description,
2024-01-15 22:49:29 +00:00
Amount: amount,
Account: uuid.NullUUID{UUID: to, Valid: true},
Category: uuid.NullUUID{},
Cleared: false,
PairKey: uuid.NullUUID{UUID: pairKey, Valid: true},
2024-01-15 22:49:29 +00:00
},
}
case AccountTypeCategory:
// Create TX with null-account
txs = []*Transaction{
{
Time: time.Now().UTC(),
Payee: fmt.Sprintf("Transfer: %s → %s", fromAcc.Name, toAcc.Name),
Description: description,
2024-01-15 22:49:29 +00:00
Amount: -amount,
Account: uuid.NullUUID{},
Category: uuid.NullUUID{UUID: from, Valid: true},
Cleared: true,
PairKey: uuid.NullUUID{UUID: pairKey, Valid: true},
2024-01-15 22:49:29 +00:00
},
{
Time: time.Now().UTC(),
Payee: fmt.Sprintf("Transfer: %s → %s", fromAcc.Name, toAcc.Name),
Description: description,
2024-01-15 22:49:29 +00:00
Amount: amount,
Account: uuid.NullUUID{},
Category: uuid.NullUUID{UUID: to, Valid: true},
Cleared: true,
PairKey: uuid.NullUUID{UUID: pairKey, Valid: true},
2024-01-15 22:49:29 +00:00
},
}
}
if err = c.retryTx(func(tx *gorm.DB) (err error) {
for _, t := range txs {
if err = tx.Save(t).Error; err != nil {
return fmt.Errorf("saving transaction: %w", err)
}
}
return nil
}); err != nil {
return fmt.Errorf("creating transactions: %w", err)
}
return nil
}
// TransferMoneyWithCategory creates new Transactions for the given
// account transfer. This is not possible for category type accounts.
func (c *Client) TransferMoneyWithCategory(from, to uuid.UUID, amount float64, description string, category uuid.UUID) (err error) {
2024-01-15 22:49:29 +00:00
var fromAcc, toAcc Account
if fromAcc, err = c.GetAccount(from); err != nil {
return fmt.Errorf("getting source account: %w", err)
}
if toAcc, err = c.GetAccount(to); err != nil {
return fmt.Errorf("getting target account: %w", err)
}
if fromAcc.Type == AccountTypeCategory || toAcc.Type == AccountTypeCategory {
return fmt.Errorf("transfer contained category-type account")
}
pairKey := uuid.Must(uuid.NewRandom())
2024-01-15 22:49:29 +00:00
if err = c.retryTx(func(tx *gorm.DB) (err error) {
fromTx := Transaction{
Time: time.Now().UTC(),
Payee: fmt.Sprintf("Transfer: %s → %s", fromAcc.Name, toAcc.Name),
Description: description,
2024-01-15 22:49:29 +00:00
Amount: -amount,
Account: uuid.NullUUID{UUID: from, Valid: true},
Category: uuid.NullUUID{},
Cleared: false,
PairKey: uuid.NullUUID{UUID: pairKey, Valid: true},
2024-01-15 22:49:29 +00:00
}
if fromAcc.Type == AccountTypeBudget {
fromTx.Category = uuid.NullUUID{UUID: category, Valid: true}
}
toTx := Transaction{
Time: time.Now().UTC(),
Payee: fmt.Sprintf("Transfer: %s → %s", fromAcc.Name, toAcc.Name),
Description: description,
2024-01-15 22:49:29 +00:00
Amount: amount,
Account: uuid.NullUUID{UUID: to, Valid: true},
Category: uuid.NullUUID{},
Cleared: false,
PairKey: uuid.NullUUID{UUID: pairKey, Valid: true},
2024-01-15 22:49:29 +00:00
}
if toAcc.Type == AccountTypeBudget {
toTx.Category = uuid.NullUUID{UUID: category, Valid: true}
}
for _, t := range []*Transaction{&fromTx, &toTx} {
if err = tx.Save(t).Error; err != nil {
return fmt.Errorf("saving transaction: %w", err)
}
}
return nil
}); err != nil {
return fmt.Errorf("creating transactions: %w", err)
}
return nil
}
// UpdateAccountHidden updates the hidden flag for the given Account
func (c *Client) UpdateAccountHidden(id uuid.UUID, hidden bool) (err error) {
if err = c.retryTx(func(db *gorm.DB) error {
return db.
Model(&Account{}).
Where("id = ?", id).
Update("hidden", hidden).
Error
}); err != nil {
return fmt.Errorf("updating account: %w", err)
}
return nil
}
// UpdateAccountName sets a new name for the given account ID
func (c *Client) UpdateAccountName(id uuid.UUID, name string) (err error) {
if err = c.retryTx(func(db *gorm.DB) error {
return db.
Model(&Account{}).
Where("id = ?", id).
Update("name", name).
Error
}); err != nil {
return fmt.Errorf("updating account: %w", err)
}
return nil
}
// UpdateTransaction takes a transaction, fetches the stored transaction
// applies some sanity actions and stores it back to the database
func (c *Client) UpdateTransaction(txID uuid.UUID, tx Transaction) (err error) {
if err = c.retryTx(func(db *gorm.DB) error {
var oldTX Transaction
if err := db.First(&oldTX, "id = ?", txID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return backoff.NewErrCannotRetry(fmt.Errorf("fetching old transaction: %w", err))
}
return fmt.Errorf("fetching old transaction: %w", err)
}
tx.ID = txID
tx.Account = oldTX.Account // Changing that would create chaos
tx.PairKey = oldTX.PairKey // Updating a paired tx should not decouple it
if err = tx.Validate(c); err != nil {
return fmt.Errorf("validating transaction: %w", err)
}
if err = db.Save(&tx).Error; err != nil {
return fmt.Errorf("saving transaction: %w", err)
}
if !oldTX.PairKey.Valid || tx.Amount == oldTX.Amount {
// is not a paired transaction or amount did not change: skip rest
return nil
}
// transaction is paired and amount changed, we need to update the
// paired transaction too or it will cause trouble
if err = db.Model(&Transaction{}).
Where("pair_key = ?", oldTX.PairKey.UUID).
Where("id <> ?", oldTX.ID).
Update("amount", -tx.Amount).
Error; err != nil {
return fmt.Errorf("updating amount for paired transaction: %w", err)
}
return nil
}); err != nil {
return fmt.Errorf("updating transaction: %w", err)
}
return nil
}
2024-01-15 22:49:29 +00:00
// UpdateTransactionCategory modifies the category of the given
// transaction. (It is not possible to remove a category with this)
func (c *Client) UpdateTransactionCategory(id uuid.UUID, cat uuid.UUID) (err error) {
if err = c.retryTx(func(db *gorm.DB) error {
var tx Transaction
if err = db.First(&tx, "id = ?", id).Error; err != nil {
return fmt.Errorf("fetching transaction: %w", err)
}
tx.Category = uuid.NullUUID{UUID: cat, Valid: true}
if err = tx.Validate(c); err != nil {
return fmt.Errorf("validating transaction: %w", err)
}
if err = db.
Save(&tx).
Error; err != nil {
return fmt.Errorf("saving transaction: %w", err)
}
return nil
}); err != nil {
return fmt.Errorf("updating transaction: %w", err)
}
return nil
}
// UpdateTransactionCleared modifies the "cleared" flag for the given
// transaction
func (c *Client) UpdateTransactionCleared(id uuid.UUID, cleared bool) (err error) {
if err = c.retryTx(func(db *gorm.DB) error {
return db.
Model(&Transaction{}).
Where("id = ?", id).
Update("cleared", cleared).
Error
}); err != nil {
return fmt.Errorf("updating transaction: %w", err)
}
return nil
}
func (c *Client) retryRead(fn func(db *gorm.DB) error) error {
//nolint:wrapcheck
return backoff.NewBackoff().
WithMaxIterations(dbMaxRetries).
Retry(func() error {
err := fn(c.db)
if errors.Is(err, gorm.ErrRecordNotFound) {
return backoff.NewErrCannotRetry(err)
}
return err
})
}
func (c *Client) retryTx(fn func(db *gorm.DB) error) error {
//nolint:wrapcheck
return backoff.NewBackoff().
WithMaxIterations(dbMaxRetries).
Retry(func() error {
return c.db.Transaction(fn)
})
}