mirror of
https://github.com/Luzifer/ansible-role-version.git
synced 2024-12-23 11:01:20 +00:00
365 lines
8.8 KiB
Go
365 lines
8.8 KiB
Go
|
package ssh_config
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"io/ioutil"
|
||
|
"log"
|
||
|
"os"
|
||
|
"path/filepath"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
)
|
||
|
|
||
|
func loadFile(t *testing.T, filename string) []byte {
|
||
|
data, err := ioutil.ReadFile(filename)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
return data
|
||
|
}
|
||
|
|
||
|
var files = []string{"testdata/config1", "testdata/config2"}
|
||
|
|
||
|
func TestDecode(t *testing.T) {
|
||
|
for _, filename := range files {
|
||
|
data := loadFile(t, filename)
|
||
|
cfg, err := Decode(bytes.NewReader(data))
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
out := cfg.String()
|
||
|
if out != string(data) {
|
||
|
t.Errorf("out != data: out: %q\ndata: %q", out, string(data))
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func testConfigFinder(filename string) func() string {
|
||
|
return func() string { return filename }
|
||
|
}
|
||
|
|
||
|
func nullConfigFinder() string {
|
||
|
return ""
|
||
|
}
|
||
|
|
||
|
func TestGet(t *testing.T) {
|
||
|
us := &UserSettings{
|
||
|
userConfigFinder: testConfigFinder("testdata/config1"),
|
||
|
}
|
||
|
|
||
|
val := us.Get("wap", "User")
|
||
|
if val != "root" {
|
||
|
t.Errorf("expected to find User root, got %q", val)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestGetWithDefault(t *testing.T) {
|
||
|
us := &UserSettings{
|
||
|
userConfigFinder: testConfigFinder("testdata/config1"),
|
||
|
}
|
||
|
|
||
|
val, err := us.GetStrict("wap", "PasswordAuthentication")
|
||
|
if err != nil {
|
||
|
t.Fatalf("expected nil err, got %v", err)
|
||
|
}
|
||
|
if val != "yes" {
|
||
|
t.Errorf("expected to get PasswordAuthentication yes, got %q", val)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestGetInvalidPort(t *testing.T) {
|
||
|
us := &UserSettings{
|
||
|
userConfigFinder: testConfigFinder("testdata/invalid-port"),
|
||
|
}
|
||
|
|
||
|
val, err := us.GetStrict("test.test", "Port")
|
||
|
if err == nil {
|
||
|
t.Fatalf("expected non-nil err, got nil")
|
||
|
}
|
||
|
if val != "" {
|
||
|
t.Errorf("expected to get '' for val, got %q", val)
|
||
|
}
|
||
|
if err.Error() != `ssh_config: strconv.ParseUint: parsing "notanumber": invalid syntax` {
|
||
|
t.Errorf("wrong error: got %v", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestGetNotFoundNoDefault(t *testing.T) {
|
||
|
us := &UserSettings{
|
||
|
userConfigFinder: testConfigFinder("testdata/config1"),
|
||
|
}
|
||
|
|
||
|
val, err := us.GetStrict("wap", "CanonicalDomains")
|
||
|
if err != nil {
|
||
|
t.Fatalf("expected nil err, got %v", err)
|
||
|
}
|
||
|
if val != "" {
|
||
|
t.Errorf("expected to get CanonicalDomains '', got %q", val)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestGetWildcard(t *testing.T) {
|
||
|
us := &UserSettings{
|
||
|
userConfigFinder: testConfigFinder("testdata/config3"),
|
||
|
}
|
||
|
|
||
|
val := us.Get("bastion.stage.i.us.example.net", "Port")
|
||
|
if val != "22" {
|
||
|
t.Errorf("expected to find Port 22, got %q", val)
|
||
|
}
|
||
|
|
||
|
val = us.Get("bastion.net", "Port")
|
||
|
if val != "25" {
|
||
|
t.Errorf("expected to find Port 24, got %q", val)
|
||
|
}
|
||
|
|
||
|
val = us.Get("10.2.3.4", "Port")
|
||
|
if val != "23" {
|
||
|
t.Errorf("expected to find Port 23, got %q", val)
|
||
|
}
|
||
|
val = us.Get("101.2.3.4", "Port")
|
||
|
if val != "25" {
|
||
|
t.Errorf("expected to find Port 24, got %q", val)
|
||
|
}
|
||
|
val = us.Get("20.20.20.4", "Port")
|
||
|
if val != "24" {
|
||
|
t.Errorf("expected to find Port 24, got %q", val)
|
||
|
}
|
||
|
val = us.Get("20.20.20.20", "Port")
|
||
|
if val != "25" {
|
||
|
t.Errorf("expected to find Port 25, got %q", val)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestGetExtraSpaces(t *testing.T) {
|
||
|
us := &UserSettings{
|
||
|
userConfigFinder: testConfigFinder("testdata/extraspace"),
|
||
|
}
|
||
|
|
||
|
val := us.Get("test.test", "Port")
|
||
|
if val != "1234" {
|
||
|
t.Errorf("expected to find Port 1234, got %q", val)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestGetCaseInsensitive(t *testing.T) {
|
||
|
us := &UserSettings{
|
||
|
userConfigFinder: testConfigFinder("testdata/config1"),
|
||
|
}
|
||
|
|
||
|
val := us.Get("wap", "uSER")
|
||
|
if val != "root" {
|
||
|
t.Errorf("expected to find User root, got %q", val)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestGetEmpty(t *testing.T) {
|
||
|
us := &UserSettings{
|
||
|
userConfigFinder: nullConfigFinder,
|
||
|
systemConfigFinder: nullConfigFinder,
|
||
|
}
|
||
|
val, err := us.GetStrict("wap", "User")
|
||
|
if err != nil {
|
||
|
t.Errorf("expected nil error, got %v", err)
|
||
|
}
|
||
|
if val != "" {
|
||
|
t.Errorf("expected to get empty string, got %q", val)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestGetEqsign(t *testing.T) {
|
||
|
us := &UserSettings{
|
||
|
userConfigFinder: testConfigFinder("testdata/eqsign"),
|
||
|
}
|
||
|
|
||
|
val := us.Get("test.test", "Port")
|
||
|
if val != "1234" {
|
||
|
t.Errorf("expected to find Port 1234, got %q", val)
|
||
|
}
|
||
|
val = us.Get("test.test", "Port2")
|
||
|
if val != "5678" {
|
||
|
t.Errorf("expected to find Port2 5678, got %q", val)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var includeFile = []byte(`
|
||
|
# This host should not exist, so we can use it for test purposes / it won't
|
||
|
# interfere with any other configurations.
|
||
|
Host kevinburke.ssh_config.test.example.com
|
||
|
Port 4567
|
||
|
`)
|
||
|
|
||
|
func TestInclude(t *testing.T) {
|
||
|
if testing.Short() {
|
||
|
t.Skip("skipping fs write in short mode")
|
||
|
}
|
||
|
testPath := filepath.Join(homedir(), ".ssh", "kevinburke-ssh-config-test-file")
|
||
|
err := ioutil.WriteFile(testPath, includeFile, 0644)
|
||
|
if err != nil {
|
||
|
t.Skipf("couldn't write SSH config file: %v", err.Error())
|
||
|
}
|
||
|
defer os.Remove(testPath)
|
||
|
us := &UserSettings{
|
||
|
userConfigFinder: testConfigFinder("testdata/include"),
|
||
|
}
|
||
|
val := us.Get("kevinburke.ssh_config.test.example.com", "Port")
|
||
|
if val != "4567" {
|
||
|
t.Errorf("expected to find Port=4567 in included file, got %q", val)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestIncludeSystem(t *testing.T) {
|
||
|
if testing.Short() {
|
||
|
t.Skip("skipping fs write in short mode")
|
||
|
}
|
||
|
testPath := filepath.Join("/", "etc", "ssh", "kevinburke-ssh-config-test-file")
|
||
|
err := ioutil.WriteFile(testPath, includeFile, 0644)
|
||
|
if err != nil {
|
||
|
t.Skipf("couldn't write SSH config file: %v", err.Error())
|
||
|
}
|
||
|
defer os.Remove(testPath)
|
||
|
us := &UserSettings{
|
||
|
systemConfigFinder: testConfigFinder("testdata/include"),
|
||
|
}
|
||
|
val := us.Get("kevinburke.ssh_config.test.example.com", "Port")
|
||
|
if val != "4567" {
|
||
|
t.Errorf("expected to find Port=4567 in included file, got %q", val)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var recursiveIncludeFile = []byte(`
|
||
|
Host kevinburke.ssh_config.test.example.com
|
||
|
Include kevinburke-ssh-config-recursive-include
|
||
|
`)
|
||
|
|
||
|
func TestIncludeRecursive(t *testing.T) {
|
||
|
if testing.Short() {
|
||
|
t.Skip("skipping fs write in short mode")
|
||
|
}
|
||
|
testPath := filepath.Join(homedir(), ".ssh", "kevinburke-ssh-config-recursive-include")
|
||
|
err := ioutil.WriteFile(testPath, recursiveIncludeFile, 0644)
|
||
|
if err != nil {
|
||
|
t.Skipf("couldn't write SSH config file: %v", err.Error())
|
||
|
}
|
||
|
defer os.Remove(testPath)
|
||
|
us := &UserSettings{
|
||
|
userConfigFinder: testConfigFinder("testdata/include-recursive"),
|
||
|
}
|
||
|
val, err := us.GetStrict("kevinburke.ssh_config.test.example.com", "Port")
|
||
|
if err != ErrDepthExceeded {
|
||
|
t.Errorf("Recursive include: expected ErrDepthExceeded, got %v", err)
|
||
|
}
|
||
|
if val != "" {
|
||
|
t.Errorf("non-empty string value %s", val)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestIncludeString(t *testing.T) {
|
||
|
if testing.Short() {
|
||
|
t.Skip("skipping fs write in short mode")
|
||
|
}
|
||
|
data, err := ioutil.ReadFile("testdata/include")
|
||
|
if err != nil {
|
||
|
log.Fatal(err)
|
||
|
}
|
||
|
c, err := Decode(bytes.NewReader(data))
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
s := c.String()
|
||
|
if s != string(data) {
|
||
|
t.Errorf("mismatch: got %q\nwant %q", s, string(data))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var matchTests = []struct {
|
||
|
in []string
|
||
|
alias string
|
||
|
want bool
|
||
|
}{
|
||
|
{[]string{"*"}, "any.test", true},
|
||
|
{[]string{"a", "b", "*", "c"}, "any.test", true},
|
||
|
{[]string{"a", "b", "c"}, "any.test", false},
|
||
|
{[]string{"any.test"}, "any1test", false},
|
||
|
{[]string{"192.168.0.?"}, "192.168.0.1", true},
|
||
|
{[]string{"192.168.0.?"}, "192.168.0.10", false},
|
||
|
{[]string{"*.co.uk"}, "bbc.co.uk", true},
|
||
|
{[]string{"*.co.uk"}, "subdomain.bbc.co.uk", true},
|
||
|
{[]string{"*.*.co.uk"}, "bbc.co.uk", false},
|
||
|
{[]string{"*.*.co.uk"}, "subdomain.bbc.co.uk", true},
|
||
|
{[]string{"*.example.com", "!*.dialup.example.com", "foo.dialup.example.com"}, "foo.dialup.example.com", false},
|
||
|
{[]string{"test.*", "!test.host"}, "test.host", false},
|
||
|
}
|
||
|
|
||
|
func TestMatches(t *testing.T) {
|
||
|
for _, tt := range matchTests {
|
||
|
patterns := make([]*Pattern, len(tt.in))
|
||
|
for i := range tt.in {
|
||
|
pat, err := NewPattern(tt.in[i])
|
||
|
if err != nil {
|
||
|
t.Fatalf("error compiling pattern %s: %v", tt.in[i], err)
|
||
|
}
|
||
|
patterns[i] = pat
|
||
|
}
|
||
|
host := &Host{
|
||
|
Patterns: patterns,
|
||
|
}
|
||
|
got := host.Matches(tt.alias)
|
||
|
if got != tt.want {
|
||
|
t.Errorf("host(%q).Matches(%q): got %v, want %v", tt.in, tt.alias, got, tt.want)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestMatchUnsupported(t *testing.T) {
|
||
|
us := &UserSettings{
|
||
|
userConfigFinder: testConfigFinder("testdata/match-directive"),
|
||
|
}
|
||
|
|
||
|
_, err := us.GetStrict("test.test", "Port")
|
||
|
if err == nil {
|
||
|
t.Fatal("expected Match directive to error, didn't")
|
||
|
}
|
||
|
if !strings.Contains(err.Error(), "ssh_config: Match directive parsing is unsupported") {
|
||
|
t.Errorf("wrong error: %v", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestIndexInRange(t *testing.T) {
|
||
|
us := &UserSettings{
|
||
|
userConfigFinder: testConfigFinder("testdata/config4"),
|
||
|
}
|
||
|
|
||
|
user, err := us.GetStrict("wap", "User")
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
if user != "root" {
|
||
|
t.Errorf("expected User to be %q, got %q", "root", user)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestDosLinesEndingsDecode(t *testing.T) {
|
||
|
us := &UserSettings{
|
||
|
userConfigFinder: testConfigFinder("testdata/dos-lines"),
|
||
|
}
|
||
|
|
||
|
user, err := us.GetStrict("wap", "User")
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
if user != "root" {
|
||
|
t.Errorf("expected User to be %q, got %q", "root", user)
|
||
|
}
|
||
|
|
||
|
host, err := us.GetStrict("wap2", "HostName")
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
if host != "8.8.8.8" {
|
||
|
t.Errorf("expected HostName to be %q, got %q", "8.8.8.8", host)
|
||
|
}
|
||
|
}
|