diff --git a/go.mod b/go.mod index 69befd4..6bec18e 100644 --- a/go.mod +++ b/go.mod @@ -7,12 +7,13 @@ require ( github.com/Luzifer/rconfig/v2 v2.4.0 github.com/hashicorp/vault/api v1.10.0 github.com/mitchellh/go-homedir v1.1.0 - github.com/pkg/errors v0.9.1 github.com/sirupsen/logrus v1.9.3 + github.com/stretchr/testify v1.8.4 ) require ( github.com/cenkalti/backoff/v3 v3.2.2 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-jose/go-jose/v3 v3.0.1 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect @@ -24,6 +25,7 @@ require ( github.com/hashicorp/go-sockaddr v1.0.6 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/ryanuber/go-glob v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect golang.org/x/crypto v0.16.0 // indirect @@ -33,4 +35,5 @@ require ( golang.org/x/time v0.5.0 // indirect gopkg.in/validator.v2 v2.0.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 52057a6..317cb68 100644 --- a/go.sum +++ b/go.sum @@ -52,8 +52,6 @@ github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= diff --git a/main.go b/main.go index c402f6e..f4d61aa 100644 --- a/main.go +++ b/main.go @@ -6,7 +6,6 @@ import ( "os" "os/exec" "strings" - "sync" "github.com/hashicorp/vault/api" "github.com/mitchellh/go-homedir" @@ -187,7 +186,21 @@ func main() { return } - obfuscate := prepareObfuscator(envData) + var ( + stdoutObfuscate = newObfuscator(os.Stdout, envData, getReplaceFn(cfg.Obfuscate)) + stderrObfuscate = newObfuscator(os.Stderr, envData, getReplaceFn(cfg.Obfuscate)) + ) + + defer func() { + if err := stderrObfuscate.Close(); err != nil { + logrus.WithError(err).Error("closing stderr") + } + }() + defer func() { + if err := stdoutObfuscate.Close(); err != nil { + logrus.WithError(err).Error("closing stdout") + } + }() emap := env.ListToMap(os.Environ()) for k, v := range emap { @@ -200,40 +213,10 @@ func main() { cmd.Stdin = os.Stdin cmd.Env = env.MapToList(envData) - stderr, err := cmd.StderrPipe() - if err != nil { - logrus.WithError(err).Fatal("getting stderr pipe") - } + cmd.Stderr = stderrObfuscate + cmd.Stdout = stdoutObfuscate - stdout, err := cmd.StdoutPipe() - if err != nil { - logrus.WithError(err).Fatal("getting stdout pipe") - } - - if err := cmd.Start(); err != nil { - logrus.WithError(err).Fatal("starting command") - } - - wg := new(sync.WaitGroup) - wg.Add(2) //nolint:gomnd - - go func() { - defer wg.Done() - if err := obfuscationTransport(stdout, os.Stdout, obfuscate); err != nil { - logrus.WithError(err).Error("obfuscating stdout") - } - }() - - go func() { - defer wg.Done() - if err := obfuscationTransport(stderr, os.Stderr, obfuscate); err != nil { - logrus.WithError(err).Error("obfuscating stderr") - } - }() - - wg.Wait() - - if err := cmd.Wait(); err != nil { - logrus.WithError(err).Fatal("error during command execution") + if err := cmd.Run(); err != nil { + logrus.WithError(err).Fatal("running command") } } diff --git a/obfuscator.go b/obfuscator.go index 0f3a8dd..1c884dc 100644 --- a/obfuscator.go +++ b/obfuscator.go @@ -1,48 +1,28 @@ package main import ( - "bufio" "crypto/sha256" "fmt" - "io" - "strings" - - "github.com/pkg/errors" ) -func prepareObfuscator(secrets map[string]string) func(string) string { - var prepare func(name, secret string) string +func replaceAsterisk(_, _ string) string { return "****" } +func replaceHash(_, secret string) string { + return fmt.Sprintf("sha256:%x", sha256.Sum256([]byte(secret))) +} +func replaceName(name, _ string) string { return name } - switch cfg.Obfuscate { +func getReplaceFn(name string) replaceFn { + switch name { case "asterisk": - prepare = func(name, secret string) string { return "****" } + return replaceAsterisk case "hash": - prepare = func(name, secret string) string { return fmt.Sprintf("sha256:%x", sha256.Sum256([]byte(secret))) } + return replaceHash case "name": - prepare = func(name, secret string) string { return name } + return replaceName default: - return func(in string) string { return in } + return nil } - - replacements := []string{} - - for k, v := range secrets { - if k != "" && v != "" { - replacements = append(replacements, v, prepare(k, v)) - } - } - repl := strings.NewReplacer(replacements...) - - return func(in string) string { return repl.Replace(in) } -} - -func obfuscationTransport(in io.Reader, out io.Writer, obfuscate func(string) string) error { - s := bufio.NewScanner(in) - for s.Scan() { - fmt.Fprintln(out, obfuscate(s.Text())) - } - return errors.Wrapf(s.Err(), "Failed to scan in buffer") } diff --git a/obfuscator_writer.go b/obfuscator_writer.go new file mode 100644 index 0000000..c76f6f7 --- /dev/null +++ b/obfuscator_writer.go @@ -0,0 +1,112 @@ +package main + +import ( + "bytes" + "fmt" + "io" + "sort" +) + +type ( + obfuscator struct { + buffer []byte + closed bool + longestSecretLen int + output io.Writer + secretReplacements [][2]string + } + + replaceFn func(name, secret string) string +) + +var _ io.WriteCloser = &obfuscator{} + +func newObfuscator(output io.Writer, secrets map[string]string, fn replaceFn) *obfuscator { + // We're looking for the longest secret: That is half the amount of + // data we need to keep in the buffer in order to detect and replace + // the secrets before forwarding the data to the real writer + var longestSecretLen int + for _, s := range secrets { + if l := len(s); l > longestSecretLen { + longestSecretLen = l + } + } + + var replacements [][2]string + if fn == nil { + // Special case: No replacer is set, we can pass-through + longestSecretLen = 0 + } else { + // Create a map of replacements + for name, secret := range secrets { + replacements = append(replacements, [2]string{secret, fn(name, secret)}) + } + } + + sort.Slice(replacements, func(j, i int) bool { + return len(replacements[i][0]) < len(replacements[j][0]) + }) + + return &obfuscator{ + longestSecretLen: longestSecretLen, + output: output, + secretReplacements: replacements, + } +} + +func (o *obfuscator) Close() (err error) { + o.closed = true + + // Do a last sweep on the remaining buffer + o.sanitizeBuffer() + + // Copy the rest to the underlying writer + if _, err = o.output.Write(o.buffer); err != nil { + return fmt.Errorf("writing remaining buffer: %w", err) + } + + return nil +} + +func (o *obfuscator) Write(data []byte) (n int, err error) { + if o.closed { + return 0, fmt.Errorf("write on closed writer") + } + + // First take everything from the input + o.buffer = append(o.buffer, data...) + + // If we haven't enough data buffered lets just pretent we wrote + // everything and in reality do nothing + if len(o.buffer) < o.longestSecretLen*2 { + return len(data), nil + } + + // Now we have at least twice the length of the longest secret in + // the buffer so we can sanitize the buffer… + o.sanitizeBuffer() + + // Now that all secrets have been replaced, we can write everything + // to the writer except the last {longestSecretLen} bytes as they + // might contain a part of the longest secret + wrLen := len(o.buffer) - o.longestSecretLen + if wrLen < 1 { + // Nothing to write, buffer was shortened too much + return len(data), nil + } + + if _, err = io.Copy(o.output, bytes.NewReader(o.buffer[:wrLen])); err != nil { + return 0, fmt.Errorf("copying sanitized data to writer: %w", err) + } + + o.buffer = o.buffer[wrLen:] + + // We took everything from them: Lets tell them we wrote everything + return len(data), nil +} + +func (o *obfuscator) sanitizeBuffer() { + for _, repl := range o.secretReplacements { + o.buffer = bytes.ReplaceAll(o.buffer, []byte(repl[0]), []byte(repl[1])) + } +} diff --git a/obfuscator_writer_test.go b/obfuscator_writer_test.go new file mode 100644 index 0000000..9a9fe8e --- /dev/null +++ b/obfuscator_writer_test.go @@ -0,0 +1,71 @@ +package main + +import ( + "bytes" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestObfuscation(t *testing.T) { + cases := []struct { + Input string + Expected string + ReplaceFn replaceFn + }{ + { + Input: "this is a longer string with a secret embedded inside", + Expected: "this is a longer string with a **** embedded inside", + ReplaceFn: replaceAsterisk, + }, + { + Input: "this is very short", + Expected: "this is very short", + ReplaceFn: replaceAsterisk, + }, + { + Input: "secret", + Expected: "****", + ReplaceFn: replaceAsterisk, + }, + { + Input: "foo", + Expected: "foo", + ReplaceFn: replaceAsterisk, + }, + { + Input: "can we have this very long secret with some special #$% chars in it obfuscated?", + Expected: "can we have **** obfuscated?", + ReplaceFn: replaceAsterisk, + }, + { + Input: "secretsecret", + Expected: "********", + ReplaceFn: replaceAsterisk, + }, + { + Input: "secret", + Expected: "secret", + ReplaceFn: nil, // Direct pass-through + }, + } + + for _, c := range cases { + t.Run(c.Input, func(t *testing.T) { + out := new(bytes.Buffer) + obf := newObfuscator(out, map[string]string{ + "mysecret": "secret", + "longsecret": "this very long secret with some special #$% chars in it", + }, c.ReplaceFn) + + _, err := fmt.Fprint(obf, c.Input) + require.NoError(t, err) + + require.NoError(t, obf.Close()) + + assert.Equal(t, c.Expected, out.String()) + }) + } +}