commit 86b2b4943e48ba411ea93b436a7e4e8de790e0b3 Author: Knut Ahlers Date: Tue Nov 14 00:40:07 2023 +0100 Initial version diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1c65a77 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +config.yaml +ipt-loadbalancer diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..6ac9584 --- /dev/null +++ b/go.mod @@ -0,0 +1,21 @@ +module git.luzifer.io/luzifer/ipt-loadbalancer + +go 1.21.3 + +require ( + github.com/Luzifer/go_helpers/v2 v2.21.0 + github.com/Luzifer/rconfig/v2 v2.4.0 + github.com/coreos/go-iptables v0.7.0 + github.com/pkg/errors v0.9.1 + github.com/sirupsen/logrus v1.9.3 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/kr/pretty v0.3.1 // indirect + github.com/mitchellh/hashstructure/v2 v2.0.2 + github.com/spf13/pflag v1.0.5 // indirect + golang.org/x/sys v0.13.0 // indirect + gopkg.in/validator.v2 v2.0.0-20210331031555-b37d688a7fb0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..013c2e6 --- /dev/null +++ b/go.sum @@ -0,0 +1,44 @@ +github.com/Luzifer/go_helpers/v2 v2.21.0 h1:kR0kdpTkYpkou3qOr2E+sXh0FxG85Mof4BlRhfSB790= +github.com/Luzifer/go_helpers/v2 v2.21.0/go.mod h1:cIIqMPu3NT8/6kHke+03hVznNDLLKVGA74Lz47CWJyA= +github.com/Luzifer/rconfig/v2 v2.4.0 h1:MAdymTlExAZ8mx5VG8xOFAtFQSpWBipKYQHPOmYTn9o= +github.com/Luzifer/rconfig/v2 v2.4.0/go.mod h1:hWF3ZVSusbYlg5bEvCwalEyUSY+0JPJWUiIu7rBmav8= +github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8= +github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= +github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/validator.v2 v2.0.0-20210331031555-b37d688a7fb0 h1:EFLtLCwd8tGN+r/ePz3cvRtdsfYNhDEdt/vp6qsT+0A= +gopkg.in/validator.v2 v2.0.0-20210331031555-b37d688a7fb0/go.mod h1:o4V0GXN9/CAmCsvJ0oXYZvrZOe7syiDZSN1GWGZTGzc= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go new file mode 100644 index 0000000..8eb6e77 --- /dev/null +++ b/main.go @@ -0,0 +1,92 @@ +package main + +import ( + "os" + + "git.luzifer.io/luzifer/ipt-loadbalancer/pkg/config" + "git.luzifer.io/luzifer/ipt-loadbalancer/pkg/iptables" + "git.luzifer.io/luzifer/ipt-loadbalancer/pkg/servicemonitor" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + + "github.com/Luzifer/rconfig/v2" +) + +var ( + cfg = struct { + Config string `flag:"config,c" default:"config.yaml" description:"Configuration file to load"` + InsertIntoPrerouting bool `flag:"insert-into-prerouting,i" default:"false" description:"Modify PREROUTING chain to contain a jump to managed chain"` + LogLevel string `flag:"log-level" default:"info" description:"Log level (debug, info, warn, error, fatal)"` + VersionAndExit bool `flag:"version" default:"false" description:"Prints current version and exits"` + }{} + + version = "dev" +) + +func initApp() error { + rconfig.AutoEnv(true) + if err := rconfig.ParseAndValidate(&cfg); err != nil { + return errors.Wrap(err, "parsing cli options") + } + + l, err := logrus.ParseLevel(cfg.LogLevel) + if err != nil { + return errors.Wrap(err, "parsing log-level") + } + logrus.SetLevel(l) + + return nil +} + +func main() { + var err error + if err = initApp(); err != nil { + logrus.WithError(err).Fatal("initializing app") + } + + if cfg.VersionAndExit { + logrus.WithField("version", version).Info("ipt-loadbalancer") + os.Exit(0) + } + + confFile, err := config.Load(cfg.Config) + if err != nil { + logrus.WithError(err).Fatal("loading config file") + } + + ipt, err := iptables.New(confFile.ManagedChain) + if err != nil { + logrus.WithError(err).Fatal("creating iptables client") + } + + if err = ipt.EnsureManagedChains(); err != nil { + logrus.WithError(err).Fatal("creating managed chain") + } + + if cfg.InsertIntoPrerouting { + if err = ipt.EnableMangedRoutingChains(); err != nil { + logrus.WithError(err).Fatal("enabling routing") + } + } + + svcErr := make(chan error, 1) + for i := range confFile.Services { + s := confFile.Services[i] + + sMon := servicemonitor.New(ipt, logrus.WithField("service", s.Name), s) + go func() { svcErr <- sMon.Run() }() + } + + logrus.WithFields(logrus.Fields{ + "services": len(confFile.Services), + "version": version, + }).Info("ipt-loadbalancer started") + + for err := range svcErr { + if err == nil { + continue + } + + logrus.WithError(err).Fatal("service monitor caused error") + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go new file mode 100644 index 0000000..d84643b --- /dev/null +++ b/pkg/config/config.go @@ -0,0 +1,84 @@ +// Package config defines the syntax of the configuration file +package config + +import ( + "bytes" + _ "embed" + "fmt" + "os" + "time" + + "github.com/Luzifer/go_helpers/v2/fieldcollection" + "gopkg.in/yaml.v3" +) + +type ( + // File wraps the whole config file content + File struct { + ManagedChain string `yaml:"managedChain"` + Services []Service `yaml:"services"` + } + + // Service represents a single service to be exposed + Service struct { + Name string `yaml:"name"` + HealthCheck ServiceHealthCheck `yaml:"healthCheck"` + LocalAddr string `yaml:"localAddr"` + LocalPort int `yaml:"localPort"` + LocalProto string `yaml:"localProto"` + Targets []Target `yaml:"targets"` + } + + // ServiceHealthCheck defines type and settings for the health- + // check to apply to the targets to deem them alive + ServiceHealthCheck struct { + Type string `yaml:"type"` + Interval time.Duration `yaml:"interval"` + Settings *fieldcollection.FieldCollection `yaml:"settings"` + } + + // Target represents a load-balancing target to route the traffic + // to in case it is deemed alive + Target struct { + Addr string `yaml:"addr"` + Port int `yaml:"port"` + Weight int `yaml:"weight"` + } +) + +//go:embed default.yaml +var defaultConfig []byte + +// Load reads the configuration file from disk and parses it over the +// included default configuration +func Load(fn string) (cf File, err error) { + defConf := yaml.NewDecoder(bytes.NewReader(defaultConfig)) + defConf.KnownFields(true) + if err = defConf.Decode(&cf); err != nil { + return cf, fmt.Errorf("unmarshalling default config: %w", err) + } + + f, err := os.Open(fn) //#nosec:G304 // This is intended to load a custom config file + if err != nil { + return cf, fmt.Errorf("opening config file: %w", err) + } + defer f.Close() //nolint:errcheck + + fileConf := yaml.NewDecoder(f) + fileConf.KnownFields(true) + if err = fileConf.Decode(&cf); err != nil { + return cf, fmt.Errorf("unmarshalling config file: %w", err) + } + + return cf, nil +} + +// Proto evaluates the LocalProto and returns tcp if empty +func (s Service) Proto() string { + if s.LocalProto == "" { + return "tcp" + } + return s.LocalProto +} + +func (t Target) String() string { return fmt.Sprintf("%s:%d", t.Addr, t.Port) } diff --git a/pkg/config/default.yaml b/pkg/config/default.yaml new file mode 100644 index 0000000..044fb53 --- /dev/null +++ b/pkg/config/default.yaml @@ -0,0 +1,6 @@ +--- + +managedChain: IPTLB +services: [] + +... diff --git a/pkg/healthcheck/common/common.go b/pkg/healthcheck/common/common.go new file mode 100644 index 0000000..a9f3f2a --- /dev/null +++ b/pkg/healthcheck/common/common.go @@ -0,0 +1,11 @@ +// Package common contains some helpers used in multiple checks +package common + +type ( + // SettingHelp is used to render a help for check config + SettingHelp struct { + Name string + Default any + Description string + } +) diff --git a/pkg/healthcheck/http/http.go b/pkg/healthcheck/http/http.go new file mode 100644 index 0000000..2c3ac47 --- /dev/null +++ b/pkg/healthcheck/http/http.go @@ -0,0 +1,128 @@ +// Package http contains a http health-check +package http + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "git.luzifer.io/luzifer/ipt-loadbalancer/pkg/config" + "git.luzifer.io/luzifer/ipt-loadbalancer/pkg/healthcheck/common" + "github.com/Luzifer/go_helpers/v2/fieldcollection" +) + +const ( + settingCode = "code" + settingExpectContent = "expectContent" + settingHost = "host" + settingInsecureTLS = "insecureTLS" + settingMethod = "method" + settingPath = "path" + settingPort = "port" + settingTimeout = "timeout" + settingTLS = "tls" +) + +type ( + // Check represents the HTTP check + Check struct{} +) + +var ( + defCode = http.StatusOK + defExpectContent = "" + defHost = "" + defInsecureTLS = false + defMethod = http.MethodGet + defPath = "/" + defTimeout = time.Second + defTLS = false +) + +// New returns a new HTTP check +func New() Check { return Check{} } + +// Check executes the check +func (c Check) Check(settings *fieldcollection.FieldCollection, target config.Target) error { + ctx, cancel := context.WithTimeout(context.Background(), settings.MustDuration(settingTimeout, &defTimeout)) + defer cancel() + + u := url.URL{ + Scheme: "http", + Host: fmt.Sprintf("%s:%d", target.Addr, settings.MustInt64(settingPort, c.intToInt64Ptr(target.Port))), + Path: settings.MustString(settingPath, &defPath), + } + + if settings.MustBool(settingTLS, &defTLS) { + u.Scheme = "https" + } + + req, err := http.NewRequestWithContext(ctx, settings.MustString(settingMethod, &defMethod), u.String(), nil) + if err != nil { + return fmt.Errorf("creating request: %w", err) + } + req.Header.Set("User-Agent", "ipt-loadbalancer/v1 (https://git.luzifer.io/luzifer/ipt-loadbalancer)") + + if hh := settings.MustString(settingHost, &defHost); hh != defHost { + req.Header.Set("Host", hh) + } + + client := http.Client{} + if settings.MustBool(settingInsecureTLS, &defInsecureTLS) { + client.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, //nolint:gosec // The intention is to use insecure TLS + }, + } + } + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("executing request: %w", err) + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode != int(settings.MustInt64(settingCode, c.intToInt64Ptr(defCode))) { + return fmt.Errorf("unexpected status code %d != %d", resp.StatusCode, settings.MustInt64(settingCode, c.intToInt64Ptr(defCode))) + } + + if settings.MustString(settingExpectContent, &defExpectContent) == defExpectContent { + return nil + } + + content, err := io.ReadAll(req.Body) + if err != nil { + return fmt.Errorf("reading response body: %w", err) + } + + if !strings.Contains(string(content), settings.MustString(settingExpectContent, &defExpectContent)) { + return fmt.Errorf("expected content not found in body") + } + + return nil +} + +// Help returns the set of settings used in the check +func (Check) Help() (help []common.SettingHelp) { + return []common.SettingHelp{ + {Name: settingCode, Default: defCode, Description: "HTTP Status-Code to expect from the request"}, + {Name: settingExpectContent, Default: defExpectContent, Description: "Content to search in the response body"}, + {Name: settingHost, Default: defHost, Description: "Host header to send with the request"}, + {Name: settingInsecureTLS, Default: defInsecureTLS, Description: "Skip TLS certificate validation"}, + {Name: settingMethod, Default: defMethod, Description: "Method to use for request"}, + {Name: settingPath, Default: defPath, Description: "Path to send the request to"}, + {Name: settingPort, Default: "target-port", Description: "Port to send the request to"}, + {Name: settingTimeout, Default: defTimeout, Description: "Timeout for the HTTP request"}, + {Name: settingTLS, Default: defTLS, Description: "Connect to port using TLS"}, + } +} + +func (Check) intToInt64Ptr(i int) *int64 { + i64 := int64(i) + return &i64 +} diff --git a/pkg/healthcheck/registry.go b/pkg/healthcheck/registry.go new file mode 100644 index 0000000..5abef1b --- /dev/null +++ b/pkg/healthcheck/registry.go @@ -0,0 +1,34 @@ +// Package healthcheck contains the interface checks have to implement +// and a registry to get them by name +package healthcheck + +import ( + "git.luzifer.io/luzifer/ipt-loadbalancer/pkg/config" + "git.luzifer.io/luzifer/ipt-loadbalancer/pkg/healthcheck/common" + "git.luzifer.io/luzifer/ipt-loadbalancer/pkg/healthcheck/http" + "git.luzifer.io/luzifer/ipt-loadbalancer/pkg/healthcheck/tcp" + "github.com/Luzifer/go_helpers/v2/fieldcollection" +) + +type ( + // Checker defines the interface a healthcheck must support + Checker interface { + Check(settings *fieldcollection.FieldCollection, target config.Target) error + Help() []common.SettingHelp + } +) + +// ByName returns the Checker for the given name or nil if that name +// is not registered +func ByName(name string) Checker { + switch name { + case "http": + return http.New() + + case "tcp": + return tcp.New() + + default: + return nil + } +} diff --git a/pkg/healthcheck/tcp/tcp.go b/pkg/healthcheck/tcp/tcp.go new file mode 100644 index 0000000..5f52edd --- /dev/null +++ b/pkg/healthcheck/tcp/tcp.go @@ -0,0 +1,58 @@ +// Package tcp implements a simple TCP health-check +package tcp + +import ( + "fmt" + "net" + "time" + + "git.luzifer.io/luzifer/ipt-loadbalancer/pkg/config" + "git.luzifer.io/luzifer/ipt-loadbalancer/pkg/healthcheck/common" + "github.com/Luzifer/go_helpers/v2/fieldcollection" +) + +const ( + settingPort = "port" + settingTimeout = "timeout" +) + +type ( + // Check represents the TCP check + Check struct{} +) + +var defTimeout = time.Second + +// New returns a new TCP check +func New() Check { return Check{} } + +// Check executes the check +func (c Check) Check(settings *fieldcollection.FieldCollection, target config.Target) error { + conn, err := net.DialTimeout( + "tcp", + fmt.Sprintf("%s:%d", target.Addr, settings.MustInt64(settingPort, c.intToInt64Ptr(target.Port))), + settings.MustDuration(settingTimeout, &defTimeout), + ) + if err != nil { + return fmt.Errorf("dialing tcp: %w", err) + } + + if err = conn.Close(); err != nil { + return fmt.Errorf("closing connection: %w", err) + } + + return nil +} + +// Help returns the set of settings used in the check +func (Check) Help() (help []common.SettingHelp) { + return []common.SettingHelp{ + {Name: settingPort, Default: "target-port", Description: "Port to send the request to"}, + {Name: settingTimeout, Default: defTimeout, Description: "Timeout for the connect"}, + } +} + +func (Check) intToInt64Ptr(i int) *int64 { + i64 := int64(i) + return &i64 +} diff --git a/pkg/iptables/iptables.go b/pkg/iptables/iptables.go new file mode 100644 index 0000000..74c9715 --- /dev/null +++ b/pkg/iptables/iptables.go @@ -0,0 +1,250 @@ +// Package iptables contains the logic to interact with the iptables +// system interface +package iptables + +import ( + "fmt" + "regexp" + "strconv" + "strings" + "sync" + + coreosIptables "github.com/coreos/go-iptables/iptables" + "github.com/mitchellh/hashstructure/v2" +) + +const ( + natTable = "nat" + probBitsize = 64 + probPrecision = 3 +) + +type ( + // Client contains the required functions to create the loadbalancing + Client struct { + *coreosIptables.IPTables + + managedChain string + + lock sync.RWMutex + services map[string][]NATTarget + } + + // NATTarget contains the configuration for a DNAT jump target + // with random distribution and given probability + NATTarget struct { + Addr string + LocalAddr string + LocalPort int + Port int + Proto string + Weight float64 + } + + // ServiceChain contains the name of the chain and a definition + // which IP/Port combination should be sent to that chain + ServiceChain struct { + Name string + Addr string + Port int + Proto string + } + + chainType uint +) + +const ( + chainTypeDNAT chainType = iota + chainTypeSNAT +) + +var disallowedChars = regexp.MustCompile(`[^A-Z0-9_]`) + +// New creates a new IPTables client +func New(managedChain string) (c *Client, err error) { + c = &Client{ + managedChain: managedChain, + + services: make(map[string][]NATTarget), + } + if c.IPTables, err = coreosIptables.New(); err != nil { + return nil, fmt.Errorf("creating iptables client: %w", err) + } + + return c, nil +} + +// EnsureManagedChains creates the managed chain referring to the +// service chains while only leading the specified address / port +// to that service chain +func (c *Client) EnsureManagedChains() (err error) { + c.lock.RLock() + defer c.lock.RUnlock() + + var ( + dnat [][]string + snat [][]string + ) + + for s := range c.services { + for chain, cType := range map[string]chainType{ + c.tableName(c.managedChain, s, "DNAT"): chainTypeDNAT, + c.tableName(c.managedChain, s, "SNAT"): chainTypeSNAT, + } { + if err = c.ensureChainWithRules(chain, c.buildServiceTable(s, cType)); err != nil { + return fmt.Errorf("creating chain %q: %w", chain, err) + } + } + + dnat = append(dnat, []string{"-j", c.tableName(c.managedChain, s, "DNAT")}) + snat = append(snat, []string{"-j", c.tableName(c.managedChain, s, "SNAT")}) + } + + dnat = append(dnat, []string{"-j", "RETURN"}) + snat = append(snat, []string{"-j", "RETURN"}) + + if err = c.ensureChainWithRules(c.tableName(c.managedChain, "DNAT"), dnat); err != nil { + return fmt.Errorf("creating managed DNAT chain: %w", err) + } + + if err = c.ensureChainWithRules(c.tableName(c.managedChain, "SNAT"), snat); err != nil { + return fmt.Errorf("creating managed SNAT chain: %w", err) + } + + return nil +} + +// EnableMangedRoutingChains inserts a jump to the given managed chains +// at position 1 of the PREROUTING and POSTROUTING chains if it does +// not already exist in the chain +func (c *Client) EnableMangedRoutingChains() (err error) { + if err = c.InsertUnique(natTable, "PREROUTING", 1, "-j", c.tableName(c.managedChain, "DNAT")); err != nil { + return fmt.Errorf("ensuring DNAT jump to managed chain: %w", err) + } + + if err = c.InsertUnique(natTable, "POSTROUTING", 1, "-j", c.tableName(c.managedChain, "SNAT")); err != nil { + return fmt.Errorf("ensuring SNAT jump to managed chain: %w", err) + } + + return nil +} + +// RegisterServiceTarget adds a new routing target to the given service +func (c *Client) RegisterServiceTarget(service string, t NATTarget) bool { + c.lock.Lock() + defer c.lock.Unlock() + + var found bool + for _, et := range c.services[service] { + found = found || et.equals(t) + } + + if !found { + c.services[service] = append(c.services[service], t) + return true + } + + return false +} + +// UnregisterServiceTarget removes a routing target from the given service +func (c *Client) UnregisterServiceTarget(service string, t NATTarget) bool { + c.lock.Lock() + defer c.lock.Unlock() + + var tmp []NATTarget + for _, et := range c.services[service] { + if !et.equals(t) { + tmp = append(tmp, et) + } + } + + if len(tmp) == len(c.services[service]) { + return false + } + + c.services[service] = tmp + return true +} + +func (c *Client) buildServiceTable(service string, cType chainType) (rules [][]string) { + weightLeft := 0.0 + for _, nt := range c.services[service] { + weightLeft += nt.Weight + } + + for _, nt := range c.services[service] { + switch cType { + case chainTypeDNAT: + rules = append(rules, []string{ + "-m", "statistic", + "--mode", "random", + "--probability", strconv.FormatFloat(nt.Weight/weightLeft, 'f', probPrecision, probBitsize), + + "-p", nt.Proto, + "-d", nt.LocalAddr, + "--dport", strconv.Itoa(nt.LocalPort), + + "-j", "DNAT", + "--to-destination", fmt.Sprintf("%s:%d", nt.Addr, nt.Port), + }) + + case chainTypeSNAT: + rules = append(rules, []string{ + "-p", nt.Proto, + "-d", nt.Addr, + "--dport", strconv.Itoa(nt.Port), + + "-j", "SNAT", + "--to-source", nt.LocalAddr, + }) + } + + weightLeft -= nt.Weight + } + + rules = append(rules, []string{"-j", "RETURN"}) + + return rules +} + +func (c *Client) ensureChainWithRules(chain string, rules [][]string) error { + chainExists, err := c.ChainExists(natTable, chain) + if err != nil { + return fmt.Errorf("checking for chain existence: %w", err) + } + + if chainExists { + if err = c.ClearChain(natTable, chain); err != nil { + return fmt.Errorf("clearing existing chain: %w", err) + } + } else { + if err = c.NewChain(natTable, chain); err != nil { + return fmt.Errorf("creating tmp-chain: %w", err) + } + } + + for _, rule := range rules { + if err = c.Append(natTable, chain, rule...); err != nil { + return fmt.Errorf("adding rule to chain: %w", err) + } + } + + return nil +} + +func (*Client) tableName(components ...string) string { + var parts []string + for _, c := range components { + parts = append(parts, disallowedChars.ReplaceAllString(strings.ToUpper(c), "_")) + } + + return strings.Join(parts, "_") +} + +func (n NATTarget) equals(c NATTarget) bool { + nh, _ := hashstructure.Hash(n, hashstructure.FormatV2, nil) + ch, _ := hashstructure.Hash(c, hashstructure.FormatV2, nil) + + return nh == ch +} diff --git a/pkg/servicemonitor/monitor.go b/pkg/servicemonitor/monitor.go new file mode 100644 index 0000000..08224cf --- /dev/null +++ b/pkg/servicemonitor/monitor.go @@ -0,0 +1,115 @@ +// Package servicemonitor contains the monitoring logic which then +// triggers a rebuild of the chain in case there is a change +package servicemonitor + +import ( + "fmt" + "sync" + "time" + + "git.luzifer.io/luzifer/ipt-loadbalancer/pkg/config" + "git.luzifer.io/luzifer/ipt-loadbalancer/pkg/healthcheck" + "git.luzifer.io/luzifer/ipt-loadbalancer/pkg/iptables" + "github.com/sirupsen/logrus" +) + +type ( + // Monitor contains the monitoring logic and state + Monitor struct { + ipt *iptables.Client + logger *logrus.Entry + svc config.Service + } +) + +// New creates a new monitor with empty rule set +func New(ipt *iptables.Client, logger *logrus.Entry, svc config.Service) *Monitor { + return &Monitor{ + ipt: ipt, + logger: logger, + svc: svc, + } +} + +// Run contains the monitoring loop for the given service and should +// run in the background. When returning an error the loop is stopped. +func (m Monitor) Run() (err error) { + for { + itStart := time.Now() + + checker := healthcheck.ByName(m.svc.HealthCheck.Type) + if checker == nil { + return fmt.Errorf("checker %q not found", m.svc.HealthCheck.Type) + } + + if err = m.updateRoutingTargets(checker); err != nil { + return fmt.Errorf("updating healthy targets: %w", err) + } + + time.Sleep(m.svc.HealthCheck.Interval - time.Since(itStart)) + } +} + +func (m Monitor) updateRoutingTargets(checker healthcheck.Checker) (err error) { + var ( + down, up []string + + changed bool + wg sync.WaitGroup + ) + wg.Add(len(m.svc.Targets)) + + for i := range m.svc.Targets { + t := m.svc.Targets[i] + logger := m.logger.WithField("target", fmt.Sprintf("%s:%d", t.Addr, t.Port)) + go func() { + defer wg.Done() + + tgt := iptables.NATTarget{ + Addr: t.Addr, + LocalAddr: m.svc.LocalAddr, + LocalPort: m.svc.LocalPort, + Port: t.Port, + Weight: float64(t.Weight), + Proto: m.svc.Proto(), + } + + if err := checker.Check(m.svc.HealthCheck.Settings, t); err != nil { + logger.WithError(err).Debug("detected target down") + changed = changed || m.ipt.UnregisterServiceTarget(m.svc.Name, tgt) + down = append(down, t.String()) + return + } + + logger.Debug("target up") + changed = changed || m.ipt.RegisterServiceTarget(m.svc.Name, tgt) + up = append(up, t.String()) + }() + } + + wg.Wait() + + uplog := m.logger.WithFields(logrus.Fields{ + "down": down, + "up": up, + }) + + switch { + case len(up) == len(up)+len(down): + uplog.Debugf("%d/%d targets up", len(up), len(up)+len(down)) + case len(up) > 0: + uplog.Warnf("%d/%d targets up", len(up), len(up)+len(down)) + case len(up) == 0: + uplog.Errorf("%d/%d targets up", len(up), len(up)+len(down)) + } + + if !changed { + return nil + } + + if err = m.ipt.EnsureManagedChains(); err != nil { + return fmt.Errorf("updating chains: %w", err) + } + + return nil +}