diff --git a/pkg/iptables/iptables.go b/pkg/iptables/iptables.go index aa5cdc7..212b3c7 100644 --- a/pkg/iptables/iptables.go +++ b/pkg/iptables/iptables.go @@ -4,6 +4,7 @@ package iptables import ( "fmt" + "net" "regexp" "strconv" "strings" @@ -11,6 +12,7 @@ import ( coreosIptables "github.com/coreos/go-iptables/iptables" "github.com/mitchellh/hashstructure/v2" + "github.com/sirupsen/logrus" ) const ( @@ -175,6 +177,26 @@ func (c *Client) buildServiceTable(service string, cType chainType) (rules [][]s } for _, nt := range c.services[service] { + var ( + bindAddr, localAddr, targetAddr string + err error + ) + + if bindAddr, err = c.translateToIP(nt.BindAddr); err != nil { + logrus.WithError(err).WithField("bind_addr", nt.BindAddr).Error("invalid address") + continue + } + + if targetAddr, err = c.translateToIP(nt.Addr); err != nil { + logrus.WithError(err).WithField("target_addr", nt.Addr).Error("invalid address") + continue + } + + if localAddr, err = c.translateToIP(nt.LocalAddr); err != nil { + logrus.WithError(err).WithField("local_addr", nt.LocalAddr).Error("invalid address") + continue + } + switch cType { case chainTypeDNAT: rules = append(rules, []string{ @@ -183,21 +205,21 @@ func (c *Client) buildServiceTable(service string, cType chainType) (rules [][]s "--probability", strconv.FormatFloat(nt.Weight/weightLeft, 'f', probPrecision, probBitsize), "-p", nt.Proto, - "-d", nt.BindAddr, + "-d", bindAddr, "--dport", strconv.Itoa(nt.BindPort), "-j", "DNAT", - "--to-destination", fmt.Sprintf("%s:%d", nt.Addr, nt.Port), + "--to-destination", fmt.Sprintf("%s:%d", targetAddr, nt.Port), }) case chainTypeSNAT: rules = append(rules, []string{ "-p", nt.Proto, - "-d", nt.Addr, + "-d", targetAddr, "--dport", strconv.Itoa(nt.Port), "-j", "SNAT", - "--to-source", nt.LocalAddr, + "--to-source", localAddr, }) } @@ -243,6 +265,29 @@ func (*Client) tableName(components ...string) string { return strings.Join(parts, "_") } +func (*Client) translateToIP(addr string) (string, error) { + ip := net.ParseIP(addr) + if ip != nil { + // We got either valid IPv4 or IPv6: Just return that. + return ip.String(), nil + } + + // Was no IP, might be a hostname: Look it up + ips, err := net.LookupIP(addr) + if err != nil { + // Definitely was none. + return "", fmt.Errorf("resolving %q to ip: %w", addr, err) + } + + if len(ips) == 0 { + // Maybe was one but had no addresses. + return "", fmt.Errorf("resolving %q did not yield IPs", addr) + } + + // Had one or more addresses, we take the first one + return ips[0].String(), nil +} + func (n NATTarget) equals(c NATTarget) bool { nh, _ := hashstructure.Hash(n, hashstructure.FormatV2, nil) ch, _ := hashstructure.Hash(c, hashstructure.FormatV2, nil)