Add support for hostnames in addresses

Signed-off-by: Knut Ahlers <knut@ahlers.me>
This commit is contained in:
Knut Ahlers 2024-04-15 22:27:13 +02:00
parent 9fdd7baa24
commit 5bf702a1a4
Signed by: luzifer
SSH key fingerprint: SHA256:/xtE5lCgiRDQr8SLxHMS92ZBlACmATUmF1crK16Ks4E

View file

@ -4,6 +4,7 @@ package iptables
import ( import (
"fmt" "fmt"
"net"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
@ -11,6 +12,7 @@ import (
coreosIptables "github.com/coreos/go-iptables/iptables" coreosIptables "github.com/coreos/go-iptables/iptables"
"github.com/mitchellh/hashstructure/v2" "github.com/mitchellh/hashstructure/v2"
"github.com/sirupsen/logrus"
) )
const ( const (
@ -175,6 +177,26 @@ func (c *Client) buildServiceTable(service string, cType chainType) (rules [][]s
} }
for _, nt := range c.services[service] { for _, nt := range c.services[service] {
var (
bindAddr, localAddr, targetAddr string
err error
)
if bindAddr, err = c.translateToIP(nt.BindAddr); err != nil {
logrus.WithError(err).WithField("bind_addr", nt.BindAddr).Error("invalid address")
continue
}
if targetAddr, err = c.translateToIP(nt.Addr); err != nil {
logrus.WithError(err).WithField("target_addr", nt.Addr).Error("invalid address")
continue
}
if localAddr, err = c.translateToIP(nt.LocalAddr); err != nil {
logrus.WithError(err).WithField("local_addr", nt.LocalAddr).Error("invalid address")
continue
}
switch cType { switch cType {
case chainTypeDNAT: case chainTypeDNAT:
rules = append(rules, []string{ rules = append(rules, []string{
@ -183,21 +205,21 @@ func (c *Client) buildServiceTable(service string, cType chainType) (rules [][]s
"--probability", strconv.FormatFloat(nt.Weight/weightLeft, 'f', probPrecision, probBitsize), "--probability", strconv.FormatFloat(nt.Weight/weightLeft, 'f', probPrecision, probBitsize),
"-p", nt.Proto, "-p", nt.Proto,
"-d", nt.BindAddr, "-d", bindAddr,
"--dport", strconv.Itoa(nt.BindPort), "--dport", strconv.Itoa(nt.BindPort),
"-j", "DNAT", "-j", "DNAT",
"--to-destination", fmt.Sprintf("%s:%d", nt.Addr, nt.Port), "--to-destination", fmt.Sprintf("%s:%d", targetAddr, nt.Port),
}) })
case chainTypeSNAT: case chainTypeSNAT:
rules = append(rules, []string{ rules = append(rules, []string{
"-p", nt.Proto, "-p", nt.Proto,
"-d", nt.Addr, "-d", targetAddr,
"--dport", strconv.Itoa(nt.Port), "--dport", strconv.Itoa(nt.Port),
"-j", "SNAT", "-j", "SNAT",
"--to-source", nt.LocalAddr, "--to-source", localAddr,
}) })
} }
@ -243,6 +265,29 @@ func (*Client) tableName(components ...string) string {
return strings.Join(parts, "_") return strings.Join(parts, "_")
} }
func (*Client) translateToIP(addr string) (string, error) {
ip := net.ParseIP(addr)
if ip != nil {
// We got either valid IPv4 or IPv6: Just return that.
return ip.String(), nil
}
// Was no IP, might be a hostname: Look it up
ips, err := net.LookupIP(addr)
if err != nil {
// Definitely was none.
return "", fmt.Errorf("resolving %q to ip: %w", addr, err)
}
if len(ips) == 0 {
// Maybe was one but had no addresses.
return "", fmt.Errorf("resolving %q did not yield IPs", addr)
}
// Had one or more addresses, we take the first one
return ips[0].String(), nil
}
func (n NATTarget) equals(c NATTarget) bool { func (n NATTarget) equals(c NATTarget) bool {
nh, _ := hashstructure.Hash(n, hashstructure.FormatV2, nil) nh, _ := hashstructure.Hash(n, hashstructure.FormatV2, nil)
ch, _ := hashstructure.Hash(c, hashstructure.FormatV2, nil) ch, _ := hashstructure.Hash(c, hashstructure.FormatV2, nil)