// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssh

import (
	"bytes"
	"math/big"
	"math/rand"
	"reflect"
	"testing"
	"testing/quick"
)

// intLengthTests pairs an integer value with the length that intLength is
// expected to report for it. Each expectation is written as a constant 4
// plus a value-dependent remainder.
var intLengthTests = []struct {
	val, length int
}{
	{val: 0, length: 4 + 0},
	{val: 1, length: 4 + 1},
	{val: 127, length: 4 + 1},
	{val: 128, length: 4 + 2},
	{val: -1, length: 4 + 1},
}

// TestIntLength runs intLength over every entry of intLengthTests and
// reports each entry whose computed length differs from the expected one.
func TestIntLength(t *testing.T) {
	for _, tc := range intLengthTests {
		var n big.Int
		n.SetInt64(int64(tc.val))
		if got := intLength(&n); got != tc.length {
			t.Errorf("For %d, got length %d but expected %d", tc.val, got, tc.length)
		}
	}
}

// msgAllTypes is the message used by TestMarshalUnmarshal to round-trip
// values through Marshal and Unmarshal. It carries one field of each of
// several kinds: bool, fixed-size byte array, unsigned integers of three
// widths, string, string list, byte slice and *big.Int.
//
// NOTE(review): the field order and the struct tags are presumably what
// Marshal/Unmarshal use to define the wire layout; confirm before
// reordering or retagging anything here.
type msgAllTypes struct {
	// Presumably the sshtype tag is the message type number carried in
	// the first byte of the packet (TestUnmarshalUnexpectedPacket expects
	// an error after that byte is altered); verify against Unmarshal.
	Bool    bool `sshtype:"21"`
	Array   [16]byte
	Uint64  uint64
	Uint32  uint32
	Uint8   uint8
	String  string
	Strings []string
	Bytes   []byte
	Int     *big.Int
	// Presumably the "rest" tag makes this field take whatever bytes
	// remain after the preceding fields; verify against Unmarshal.
	Rest    []byte `ssh:"rest"`
}

// Generate returns a pointer to a freshly randomised msgAllTypes wrapped in
// a reflect.Value, which lets testing/quick build values of this type.
//
// The receiver and size are not used. String, Bytes and Rest are all derived
// from the random contents of Array; Bytes and Rest are slices of that same
// array rather than copies.
func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value {
	var msg msgAllTypes

	// The order of the calls on rand below determines which values come
	// out for a given seed, so it must not be rearranged.
	msg.Bool = rand.Intn(2) == 1
	randomBytes(msg.Array[:], rand)
	msg.Uint64 = uint64(rand.Int63n(1<<63 - 1))
	msg.Uint32 = uint32(rand.Intn(1<<31 - 1))
	msg.Uint8 = uint8(rand.Intn(1 << 8))
	msg.Strings = randomNameList(rand)
	msg.Int = randomInt(rand)

	// Everything byte-like reuses the array filled in above.
	payload := msg.Array[:]
	msg.String = string(payload)
	msg.Bytes = payload
	msg.Rest = payload

	return reflect.ValueOf(&msg)
}

// TestMarshalUnmarshal round-trips randomly generated msgAllTypes values
// through Marshal and Unmarshal and checks that the decoded message is
// deeply equal to the generated one. It runs 100 iterations, or 5 in short
// mode, from a fixed seed, and stops at the first failure.
func TestMarshalUnmarshal(t *testing.T) {
	rng := rand.New(rand.NewSource(0))
	target := &msgAllTypes{}
	ptrType := reflect.ValueOf(target).Type()

	iterations := 100
	if testing.Short() {
		iterations = 5
	}

	for iter := 0; iter < iterations; iter++ {
		generated, ok := quick.Value(ptrType, rng)
		if !ok {
			t.Errorf("failed to create value")
			return
		}

		// Marshal the struct value itself; decode into the shared target,
		// which is reused across iterations.
		original := generated.Elem().Interface()
		wire := Marshal(original)

		err := Unmarshal(wire, target)
		if err != nil {
			t.Errorf("Unmarshal %#v: %s", original, err)
			return
		}

		if equal := reflect.DeepEqual(generated.Interface(), target); !equal {
			t.Errorf("got: %#v\nwant:%#v\n%x", target, original, wire)
			return
		}
	}
}

// TestUnmarshalEmptyPacket checks that Unmarshal reports an error when
// handed a nil byte slice.
func TestUnmarshalEmptyPacket(t *testing.T) {
	var empty []byte
	msg := new(channelRequestSuccessMsg)

	err := Unmarshal(empty, msg)
	if err == nil {
		t.Fatalf("unmarshal of empty slice succeeded")
	}
}

// TestUnmarshalUnexpectedPacket marshals a struct tagged sshtype "43",
// overwrites the first byte of the packet with 42 and checks that Unmarshal
// into the same struct type then fails.
func TestUnmarshalUnexpectedPacket(t *testing.T) {
	type S struct {
		I uint32 `sshtype:"43"`
		S string
		B bool
	}

	packet := Marshal(S{11, "hello", true})
	// Presumably the first byte carries the message type from the sshtype
	// tag; 42 is deliberately different from the tagged 43.
	packet[0] = 42

	var roundtrip S
	if err := Unmarshal(packet, &roundtrip); err == nil {
		t.Fatal("expected error, not nil")
	}
}

// TestMarshalPtr checks that Marshal produces the same bytes for a struct
// value and for a pointer to that struct.
func TestMarshalPtr(t *testing.T) {
	value := struct {
		S string
	}{S: "hello"}

	fromValue := Marshal(value)
	fromPointer := Marshal(&value)
	if bytes.Equal(fromValue, fromPointer) {
		return
	}
	t.Errorf("got %q, want %q for marshaled pointer", fromPointer, fromValue)
}

// TestBareMarshalUnmarshal round-trips a struct that has no sshtype tag
// through Marshal and Unmarshal and checks that the decoded value is deeply
// equal to the original.
func TestBareMarshalUnmarshal(t *testing.T) {
	type S struct {
		I uint32
		S string
		B bool
	}

	s := S{42, "hello", true}
	packet := Marshal(s)
	roundtrip := S{}
	// Report a decode failure directly instead of letting it surface only
	// as a mismatch against a partially filled struct.
	if err := Unmarshal(packet, &roundtrip); err != nil {
		t.Fatalf("Unmarshal(%q): %v", packet, err)
	}

	if !reflect.DeepEqual(s, roundtrip) {
		t.Errorf("got %#v, want %#v", roundtrip, s)
	}
}

// TestBareMarshal marshals a struct holding a single uint32 and no sshtype
// tag, then checks that parseUint32 consumes the whole packet and yields the
// original value.
func TestBareMarshal(t *testing.T) {
	type S2 struct {
		I uint32
	}
	s := S2{42}
	packet := Marshal(s)
	i, rest, ok := parseUint32(packet)
	if len(rest) > 0 || !ok {
		// The message names the function that was actually called.
		t.Errorf("parseUint32(%q): parse error", packet)
	}
	if i != s.I {
		t.Errorf("got %d, want %d", i, s.I)
	}
}

// TestUnmarshalShortKexInitPacket feeds Unmarshal a fixed 21-byte packet and
// expects an error when decoding it into a kexInitMsg.
//
// This used to panic.
// Issue 11348
func TestUnmarshalShortKexInitPacket(t *testing.T) {
	packet := []byte{
		0x14,
		0x30, 0x30, 0x30, 0x30,
		0x30, 0x30, 0x30, 0x30,
		0x30, 0x30, 0x30, 0x30,
		0x30, 0x30, 0x30, 0x30,
		0xff, 0xff, 0xff, 0xff,
	}

	var kim kexInitMsg
	err := Unmarshal(packet, &kim)
	if err == nil {
		t.Error("truncated packet unmarshaled without error")
	}
}

// TestMarshalMultiTag checks decoding into a struct tagged sshtype "1|2":
// messages marshaled from structs tagged "1" or "2" must decode without
// error, while one tagged "3" must be rejected.
func TestMarshalMultiTag(t *testing.T) {
	var res struct {
		A uint32 `sshtype:"1|2"`
	}

	good1 := struct {
		A uint32 `sshtype:"1"`
	}{A: 1}
	good2 := struct {
		A uint32 `sshtype:"2"`
	}{A: 1}
	bad1 := struct {
		A uint32 `sshtype:"3"`
	}{A: 1}

	// Both accepted tags are exercised in order against the same target.
	for _, msg := range []interface{}{good1, good2} {
		if e := Unmarshal(Marshal(msg), &res); e != nil {
			t.Errorf("error unmarshaling multipart tag: %v", e)
		}
	}

	if e := Unmarshal(Marshal(bad1), &res); e == nil {
		t.Errorf("bad struct unmarshaled without error")
	}
}

// randomBytes overwrites every element of out with the low byte of a fresh
// rand.Int31 draw, consuming exactly one draw per element.
func randomBytes(out []byte, rand *rand.Rand) {
	for i := range out {
		out[i] = byte(rand.Int31())
	}
}

// randomNameList returns between 0 and 15 random names. Each name is 1 to
// 16 bytes long and every byte lies in the range 'a' through 'p'. The result
// is never nil, even when it holds no names.
func randomNameList(rand *rand.Rand) []string {
	count := int(rand.Int31() & 15)
	names := make([]string, count)
	for idx := 0; idx < count; idx++ {
		nameLen := 1 + int(rand.Int31()&15)
		name := make([]byte, 0, nameLen)
		for len(name) < nameLen {
			name = append(name, 'a'+byte(rand.Int31()&15))
		}
		names[idx] = string(name)
	}
	return names
}

// randomInt returns a new big.Int holding one rand.Uint32 draw reinterpreted
// as a signed 32-bit value, so the result may be negative.
func randomInt(rand *rand.Rand) *big.Int {
	signed := int32(rand.Uint32())
	result := new(big.Int)
	result.SetInt64(int64(signed))
	return result
}

// Generate returns a pointer to a randomly populated kexInitMsg wrapped in a
// reflect.Value. The cookie is filled with random bytes, each algorithm and
// language list gets its own random name list, and FirstKexFollows is set
// from the low bit of one further draw. The size argument is not used.
func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	var msg kexInitMsg
	randomBytes(msg.Cookie[:], rand)

	// The lists are filled in this fixed order so that a given seed always
	// produces the same message.
	lists := []*[]string{
		&msg.KexAlgos,
		&msg.ServerHostKeyAlgos,
		&msg.CiphersClientServer,
		&msg.CiphersServerClient,
		&msg.MACsClientServer,
		&msg.MACsServerClient,
		&msg.CompressionClientServer,
		&msg.CompressionServerClient,
		&msg.LanguagesClientServer,
		&msg.LanguagesServerClient,
	}
	for _, list := range lists {
		*list = randomNameList(rand)
	}

	msg.FirstKexFollows = rand.Int31()&1 == 1
	return reflect.ValueOf(&msg)
}

// Generate returns a pointer to a kexDHInitMsg whose X is a random integer
// from randomInt, wrapped in a reflect.Value. The size argument is not used.
func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	return reflect.ValueOf(&kexDHInitMsg{X: randomInt(rand)})
}

// Fixed inputs for the kexInitMsg benchmarks: a message generated from a
// seed-0 random source and its marshaled form.
var (
	_kexInitMsg = (&kexInitMsg{}).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
	_kexInit    = Marshal(_kexInitMsg)
)

// Fixed inputs for the kexDHInitMsg benchmarks, built the same way from
// their own seed-0 random source.
var (
	_kexDHInitMsg = (&kexDHInitMsg{}).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
	_kexDHInit    = Marshal(_kexDHInitMsg)
)

// BenchmarkMarshalKexInitMsg measures Marshal on the fixed _kexInitMsg value.
func BenchmarkMarshalKexInitMsg(b *testing.B) {
	for remaining := b.N; remaining > 0; remaining-- {
		Marshal(_kexInitMsg)
	}
}

// BenchmarkUnmarshalKexInitMsg measures Unmarshal of the fixed _kexInit
// packet into a single reused kexInitMsg.
func BenchmarkUnmarshalKexInitMsg(b *testing.B) {
	m := new(kexInitMsg)
	for i := 0; i < b.N; i++ {
		// Abort on a decode error so the benchmark cannot silently time
		// the failure path.
		if err := Unmarshal(_kexInit, m); err != nil {
			b.Fatal(err)
		}
	}
}

// BenchmarkMarshalKexDHInitMsg measures Marshal on the fixed _kexDHInitMsg
// value.
func BenchmarkMarshalKexDHInitMsg(b *testing.B) {
	for remaining := b.N; remaining > 0; remaining-- {
		Marshal(_kexDHInitMsg)
	}
}

// BenchmarkUnmarshalKexDHInitMsg measures Unmarshal of the fixed _kexDHInit
// packet into a single reused kexDHInitMsg.
func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) {
	m := new(kexDHInitMsg)
	for i := 0; i < b.N; i++ {
		// Abort on a decode error so the benchmark cannot silently time
		// the failure path.
		if err := Unmarshal(_kexDHInit, m); err != nil {
			b.Fatal(err)
		}
	}
}