commit b3fc17cb4ee87a4672c6c5a7588d1573a391dcae Author: Knut Ahlers Date: Thu Nov 28 18:09:18 2019 +0100 Initial version diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..04bad90 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.env +id_rsa +id_rsa.pub +shareport diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..c1e0672 --- /dev/null +++ b/go.mod @@ -0,0 +1,10 @@ +module github.com/Luzifer/shareport + +go 1.13 + +require ( + github.com/Luzifer/rconfig/v2 v2.2.1 + github.com/pkg/errors v0.8.1 + github.com/sirupsen/logrus v1.4.2 + golang.org/x/crypto v0.0.0-20191122220453-ac88ee75c92c +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..62b5e46 --- /dev/null +++ b/go.sum @@ -0,0 +1,29 @@ +github.com/Luzifer/rconfig v2.2.0+incompatible h1:Kle3+rshPM7LxciOheaR4EfHUzibkDDGws04sefQ5m8= +github.com/Luzifer/rconfig/v2 v2.2.1 h1:zcDdLQlnlzwcBJ8E0WFzOkQE1pCMn3EbX0dFYkeTczg= +github.com/Luzifer/rconfig/v2 v2.2.1/go.mod h1:OKIX0/JRZrPJ/ZXXWklQEFXA6tBfWaljZbW37w+sqBw= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191122220453-ac88ee75c92c h1:/nJuwDLoL/zrqY6gf57vxC+Pi+pZ8bfhpPkicO5H7W4= +golang.org/x/crypto v0.0.0-20191122220453-ac88ee75c92c/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/validator.v2 v2.0.0-20180514200540-135c24b11c19 h1:WB265cn5OpO+hK3pikC9hpP1zI/KTwmyMFKloW9eOVc= +gopkg.in/validator.v2 v2.0.0-20180514200540-135c24b11c19/go.mod h1:o4V0GXN9/CAmCsvJ0oXYZvrZOe7syiDZSN1GWGZTGzc= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/main.go b/main.go new file mode 100644 index 0000000..8204667 --- /dev/null +++ b/main.go @@ -0,0 +1,223 @@ +package main + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net" + "os" + "os/signal" + "os/user" + "path" + "syscall" + + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + + "github.com/Luzifer/rconfig/v2" +) + +var ( + cfg = struct { + DebugRemote bool `flag:"debug-remote" default:"false" description:"Send remote stderr local terminal"` + IdentityFile string `flag:"identity-file,i" vardefault:"ssh_key" description:"Identity file to use for connecting to the remote"` + IdentityFilePassword string `flag:"identity-file-password" default:"" description:"Password for the identity file"` + LocalAddr string `flag:"local-addr,l" default:"" description:"Local address / port to forward" validate:"nonzero"` + LogLevel string `flag:"log-level" default:"info" description:"Log level (debug, info, warn, error, fatal)"` + RemoteHost string `flag:"remote-host" default:"" description:"Remote host and port in format host:port" validate:"nonzero"` + RemoteCommand string `flag:"remote-command" default:"" description:"Remote command to execute after connect"` + RemoteListen string `flag:"remote-listen" default:"localhost:0" description:"Address to listen on remote (port is available in script)"` + RemoteScript string `flag:"remote-script" default:"" description:"Bash script to push and execute (overwrites remote-command)"` + RemoteUser string `flag:"remote-user" vardefault:"remote_user" description:"User to use to connect to remote host"` + VersionAndExit bool `flag:"version" default:"false" description:"Prints current version and exits"` + }{} + + running = true + + version = "dev" +) + +func forward(remoteConn net.Conn) { + defer remoteConn.Close() + + localConn, err := net.Dial("tcp", cfg.LocalAddr) + if err != nil { + log.WithError(err).Error("Unable to connect to local address") + return + } + defer localConn.Close() + + copyConn := func(w, r net.Conn, wg chan struct{}) { + _, err := io.Copy(w, r) + if err != nil { + log.WithError(err).Debug("IO copy caused an error, terminating connection") + } + wg <- struct{}{} + } + + var wg = make(chan struct{}, 2) + + go copyConn(localConn, remoteConn, wg) + go copyConn(remoteConn, localConn, wg) + + <-wg +} + +func genDefaults() map[string]string { + defs := map[string]string{} + + if userHome, err := os.UserHomeDir(); err == nil { + defs["ssh_key"] = path.Join(userHome, ".ssh", "id_rsa") + } + + if user, err := user.Current(); err == nil { + defs["remote_user"] = user.Username + } + + return defs +} + +func init() { + rconfig.SetVariableDefaults(genDefaults()) + + rconfig.AutoEnv(true) + if err := rconfig.ParseAndValidate(&cfg); err != nil { + log.Fatalf("Unable to parse commandline options: %s", err) + } + + if cfg.VersionAndExit { + fmt.Printf("shareport %s\n", version) + os.Exit(0) + } + + if l, err := log.ParseLevel(cfg.LogLevel); err != nil { + log.WithError(err).Fatal("Unable to parse log level") + } else { + log.SetLevel(l) + } +} + +func loadPrivateKey() (ssh.AuthMethod, error) { + kf, err := ioutil.ReadFile(cfg.IdentityFile) + if err != nil { + return nil, errors.Wrap(err, "Unable to read key file") + } + + pk, err := signerFromPem(kf, []byte(cfg.IdentityFilePassword)) + return ssh.PublicKeys(pk), errors.Wrap(err, "Unable to parse private key") +} + +func main() { + sigC := make(chan os.Signal) + signal.Notify(sigC, syscall.SIGINT, syscall.SIGTERM) + + privateKey, err := loadPrivateKey() + if err != nil { + log.WithError(err).Fatal("Unable to load key") + } + + config := &ssh.ClientConfig{ + User: cfg.RemoteUser, + Auth: []ssh.AuthMethod{privateKey}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + // Connect to remote + client, err := ssh.Dial("tcp", cfg.RemoteHost, config) + if err != nil { + log.WithError(err).Fatal("Unable to connect to remote host") + } + + // Open port for us to listen on + remoteListener, err := client.Listen("tcp", cfg.RemoteListen) + if err != nil { + log.WithError(err).Fatal("Unable to listen for connection") + } + defer remoteListener.Close() + + _, port, err := net.SplitHostPort(remoteListener.Addr().String()) + if err != nil { + log.WithError(err).Fatal("Unable to get port of remote listen socket") + } + + log.WithField("port", port).Debug("Remote port established") + + go func() { + for running { + remoteConn, err := remoteListener.Accept() + if err != nil { + log.WithError(err).Error("Unable to accept remote connection") + continue + } + + go forward(remoteConn) + } + }() + + // Initialize script + var scriptIn = new(bytes.Buffer) + fmt.Fprintln(scriptIn, "set -euxo pipefail") + + // Create remote script session + session, err := client.NewSession() + if err != nil { + log.WithError(err).Fatal("Unable to open remote session") + } + defer session.Close() + + for k, v := range map[string]string{ + "PORT": port, + "LISTEN": remoteListener.Addr().String(), + } { + fmt.Fprintf(scriptIn, "export %s=%q\n", k, v) + } + + switch { + case cfg.RemoteScript != "": + script, err := ioutil.ReadFile(cfg.RemoteScript) + if err != nil { + log.WithError(err).Fatal("Unable to load remote-script") + } + scriptIn.Write(script) + + case cfg.RemoteCommand != "": + fmt.Fprintf(scriptIn, "exec %s", cfg.RemoteCommand) + + default: + log.Fatal("Neither remote-command nor remote-script specified") + } + + if cfg.DebugRemote { + session.Stderr = os.Stderr + } else { + session.Stderr = ioutil.Discard + } + + session.Stdin = scriptIn + session.Stdout = os.Stdout + + if err := session.Start("/bin/bash -euxo pipefail"); err != nil { + log.WithError(err).Fatal("Unable to spawn remote command") + } + + go func() { + if err := session.Wait(); err != nil { + log.WithError(err).Error("Remote process caused an error") + } + sigC <- syscall.SIGINT + }() + + for { + select { + case <-sigC: + log.Info("Signal triggered, shutting down") + if err := session.Signal(ssh.SIGHUP); err != nil { + log.WithError(err).Error("Unable to send TERM signal to remote process") + } + running = false + return + } + } +} diff --git a/ssh_key.go b/ssh_key.go new file mode 100644 index 0000000..b281ae2 --- /dev/null +++ b/ssh_key.go @@ -0,0 +1,79 @@ +package main + +import ( + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + + "golang.org/x/crypto/ssh" +) + +func signerFromPem(pemBytes []byte, password []byte) (ssh.Signer, error) { + + // read pem block + err := errors.New("Pem decode failed, no key found") + pemBlock, _ := pem.Decode(pemBytes) + if pemBlock == nil { + return nil, err + } + + // handle encrypted key + if x509.IsEncryptedPEMBlock(pemBlock) { + // decrypt PEM + pemBlock.Bytes, err = x509.DecryptPEMBlock(pemBlock, password) + if err != nil { + return nil, fmt.Errorf("Decrypting PEM block failed %v", err) + } + + // get RSA, EC or DSA key + key, err := parsePemBlock(pemBlock) + if err != nil { + return nil, err + } + + // generate signer instance from key + signer, err := ssh.NewSignerFromKey(key) + if err != nil { + return nil, fmt.Errorf("Creating signer from encrypted key failed %v", err) + } + + return signer, nil + } else { + // generate signer instance from plain key + signer, err := ssh.ParsePrivateKey(pemBytes) + if err != nil { + return nil, fmt.Errorf("Parsing plain private key failed %v", err) + } + + return signer, nil + } +} + +func parsePemBlock(block *pem.Block) (interface{}, error) { + switch block.Type { + case "RSA PRIVATE KEY": + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("Parsing PKCS private key failed %v", err) + } else { + return key, nil + } + case "EC PRIVATE KEY": + key, err := x509.ParseECPrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("Parsing EC private key failed %v", err) + } else { + return key, nil + } + case "DSA PRIVATE KEY": + key, err := ssh.ParseDSAPrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("Parsing DSA private key failed %v", err) + } else { + return key, nil + } + default: + return nil, fmt.Errorf("Parsing private key failed, unsupported key type %q", block.Type) + } +}