mirror of
https://github.com/Luzifer/nginx-sso.git
synced 2025-01-02 03:01:16 +00:00
337 lines
9.7 KiB
Go
337 lines
9.7 KiB
Go
|
package ldap
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"errors"
|
||
|
"io"
|
||
|
"net"
|
||
|
"net/http"
|
||
|
"net/http/httptest"
|
||
|
"runtime"
|
||
|
"sync"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"gopkg.in/asn1-ber.v1"
|
||
|
)
|
||
|
|
||
|
func TestUnresponsiveConnection(t *testing.T) {
|
||
|
// The do-nothing server that accepts requests and does nothing
|
||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
}))
|
||
|
defer ts.Close()
|
||
|
c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
|
||
|
if err != nil {
|
||
|
t.Fatalf("error connecting to localhost tcp: %v", err)
|
||
|
}
|
||
|
|
||
|
// Create an Ldap connection
|
||
|
conn := NewConn(c, false)
|
||
|
conn.SetTimeout(time.Millisecond)
|
||
|
conn.Start()
|
||
|
defer conn.Close()
|
||
|
|
||
|
// Mock a packet
|
||
|
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
|
||
|
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, conn.nextMessageID(), "MessageID"))
|
||
|
bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
|
||
|
bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
|
||
|
packet.AppendChild(bindRequest)
|
||
|
|
||
|
// Send packet and test response
|
||
|
msgCtx, err := conn.sendMessage(packet)
|
||
|
if err != nil {
|
||
|
t.Fatalf("error sending message: %v", err)
|
||
|
}
|
||
|
defer conn.finishMessage(msgCtx)
|
||
|
|
||
|
packetResponse, ok := <-msgCtx.responses
|
||
|
if !ok {
|
||
|
t.Fatalf("no PacketResponse in response channel")
|
||
|
}
|
||
|
packet, err = packetResponse.ReadPacket()
|
||
|
if err == nil {
|
||
|
t.Fatalf("expected timeout error")
|
||
|
}
|
||
|
if err.Error() != "ldap: connection timed out" {
|
||
|
t.Fatalf("unexpected error: %v", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// TestFinishMessage tests that we do not enter deadlock when a goroutine makes
|
||
|
// a request but does not handle all responses from the server.
|
||
|
func TestFinishMessage(t *testing.T) {
|
||
|
ptc := newPacketTranslatorConn()
|
||
|
defer ptc.Close()
|
||
|
|
||
|
conn := NewConn(ptc, false)
|
||
|
conn.Start()
|
||
|
|
||
|
// Test sending 5 different requests in series. Ensure that we can
|
||
|
// get a response packet from the underlying connection and also
|
||
|
// ensure that we can gracefully ignore unhandled responses.
|
||
|
for i := 0; i < 5; i++ {
|
||
|
t.Logf("serial request %d", i)
|
||
|
// Create a message and make sure we can receive responses.
|
||
|
msgCtx := testSendRequest(t, ptc, conn)
|
||
|
testReceiveResponse(t, ptc, msgCtx)
|
||
|
|
||
|
// Send a few unhandled responses and finish the message.
|
||
|
testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5)
|
||
|
t.Logf("serial request %d done", i)
|
||
|
}
|
||
|
|
||
|
// Test sending 5 different requests in parallel.
|
||
|
var wg sync.WaitGroup
|
||
|
for i := 0; i < 5; i++ {
|
||
|
wg.Add(1)
|
||
|
go func(i int) {
|
||
|
defer wg.Done()
|
||
|
t.Logf("parallel request %d", i)
|
||
|
// Create a message and make sure we can receive responses.
|
||
|
msgCtx := testSendRequest(t, ptc, conn)
|
||
|
testReceiveResponse(t, ptc, msgCtx)
|
||
|
|
||
|
// Send a few unhandled responses and finish the message.
|
||
|
testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5)
|
||
|
t.Logf("parallel request %d done", i)
|
||
|
}(i)
|
||
|
}
|
||
|
wg.Wait()
|
||
|
|
||
|
// We cannot run Close() in a defer because t.FailNow() will run it and
|
||
|
// it will block if the processMessage Loop is in a deadlock.
|
||
|
conn.Close()
|
||
|
}
|
||
|
|
||
|
func testSendRequest(t *testing.T, ptc *packetTranslatorConn, conn *Conn) (msgCtx *messageContext) {
|
||
|
var msgID int64
|
||
|
runWithTimeout(t, time.Second, func() {
|
||
|
msgID = conn.nextMessageID()
|
||
|
})
|
||
|
|
||
|
requestPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
|
||
|
requestPacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgID, "MessageID"))
|
||
|
|
||
|
var err error
|
||
|
|
||
|
runWithTimeout(t, time.Second, func() {
|
||
|
msgCtx, err = conn.sendMessage(requestPacket)
|
||
|
if err != nil {
|
||
|
t.Fatalf("unable to send request message: %s", err)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
// We should now be able to get this request packet out from the other
|
||
|
// side.
|
||
|
runWithTimeout(t, time.Second, func() {
|
||
|
if _, err = ptc.ReceiveRequest(); err != nil {
|
||
|
t.Fatalf("unable to receive request packet: %s", err)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
return msgCtx
|
||
|
}
|
||
|
|
||
|
func testReceiveResponse(t *testing.T, ptc *packetTranslatorConn, msgCtx *messageContext) {
|
||
|
// Send a mock response packet.
|
||
|
responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
|
||
|
responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID"))
|
||
|
|
||
|
runWithTimeout(t, time.Second, func() {
|
||
|
if err := ptc.SendResponse(responsePacket); err != nil {
|
||
|
t.Fatalf("unable to send response packet: %s", err)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
// We should be able to receive the packet from the connection.
|
||
|
runWithTimeout(t, time.Second, func() {
|
||
|
if _, ok := <-msgCtx.responses; !ok {
|
||
|
t.Fatal("response channel closed")
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func testSendUnhandledResponsesAndFinish(t *testing.T, ptc *packetTranslatorConn, conn *Conn, msgCtx *messageContext, numResponses int) {
|
||
|
// Send a mock response packet.
|
||
|
responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
|
||
|
responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID"))
|
||
|
|
||
|
// Send extra responses but do not attempt to receive them on the
|
||
|
// client side.
|
||
|
for i := 0; i < numResponses; i++ {
|
||
|
runWithTimeout(t, time.Second, func() {
|
||
|
if err := ptc.SendResponse(responsePacket); err != nil {
|
||
|
t.Fatalf("unable to send response packet: %s", err)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// Finally, attempt to finish this message.
|
||
|
runWithTimeout(t, time.Second, func() {
|
||
|
conn.finishMessage(msgCtx)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// runWithTimeout runs f in a new goroutine and waits until it finishes or
// until timeout has elapsed, whichever happens first.
//
// If the timeout elapses first, the test is failed with t.Fatalf, reporting
// the file and line of the function that called runWithTimeout. The goroutine
// running f is not stopped in that case.
//
// If f ends early through runtime.Goexit (which is how t.Fatal, t.FailNow and
// t.SkipNow stop a goroutine), the waiter is released immediately instead of
// sitting out the whole timeout and reporting a misleading "timed out". When
// the test has been marked as failed by then, the caller is stopped as well
// with t.FailNow.
func runWithTimeout(t *testing.T, timeout time.Duration, f func()) {
	done := make(chan struct{})
	completed := false // set only when f returns normally; read after <-done

	go func() {
		// Deferred so that done is closed even when f never returns
		// because it terminated this goroutine with runtime.Goexit.
		defer close(done)
		f()
		completed = true
	}()

	select {
	case <-done:
		if !completed && t.Failed() {
			// f aborted after recording a failure; propagate the abort to
			// the calling goroutine rather than carrying on.
			t.FailNow()
		}
	case <-time.After(timeout):
		_, file, line, _ := runtime.Caller(1)
		t.Fatalf("%s:%d timed out", file, line)
	}
}
|
||
|
|
||
|
// packetTranslatorConn is an in-memory stand-in for a network connection that
// tests in this package can hand to a *ldap.Conn. It implements the net.Conn
// interface; most of its methods are no-ops, while Read() and Write() move
// ber-encoded packets through two internal buffers.
//
// A test plays the role of the LDAP server: SendResponse() queues a
// ber-encoded response packet for the client to read, and ReceiveRequest()
// returns the next ber-encoded request packet the client has written.
type packetTranslatorConn struct {
	// lock is held by the methods while they touch the fields below; the
	// constructor also installs it as the Locker of both condition variables.
	lock     sync.Mutex
	isClosed bool // set by Close()

	// responseCond is broadcast when response bytes were queued or the
	// connection was closed; requestCond likewise for request bytes.
	responseCond, requestCond sync.Cond

	// responseBuf holds bytes travelling server -> client, requestBuf holds
	// bytes travelling client -> server.
	responseBuf, requestBuf bytes.Buffer
}
|
||
|
|
||
|
// errPacketTranslatorConnClosed is the error returned by the
// packetTranslatorConn methods once the connection has been closed.
var errPacketTranslatorConnClosed = errors.New("connection closed")
|
||
|
|
||
|
func newPacketTranslatorConn() *packetTranslatorConn {
|
||
|
conn := &packetTranslatorConn{}
|
||
|
conn.responseCond = sync.Cond{L: &conn.lock}
|
||
|
conn.requestCond = sync.Cond{L: &conn.lock}
|
||
|
|
||
|
return conn
|
||
|
}
|
||
|
|
||
|
// Read is called by the reader() loop to receive response packets. It will
|
||
|
// block until there are more packet bytes available or this connection is
|
||
|
// closed.
|
||
|
func (c *packetTranslatorConn) Read(b []byte) (n int, err error) {
|
||
|
c.lock.Lock()
|
||
|
defer c.lock.Unlock()
|
||
|
|
||
|
for !c.isClosed {
|
||
|
// Attempt to read data from the response buffer. If it fails
|
||
|
// with an EOF, wait and try again.
|
||
|
n, err = c.responseBuf.Read(b)
|
||
|
if err != io.EOF {
|
||
|
return n, err
|
||
|
}
|
||
|
|
||
|
c.responseCond.Wait()
|
||
|
}
|
||
|
|
||
|
return 0, errPacketTranslatorConnClosed
|
||
|
}
|
||
|
|
||
|
// SendResponse writes the given response packet to the response buffer for
|
||
|
// this connection, signalling any goroutine waiting to read a response.
|
||
|
func (c *packetTranslatorConn) SendResponse(packet *ber.Packet) error {
|
||
|
c.lock.Lock()
|
||
|
defer c.lock.Unlock()
|
||
|
|
||
|
if c.isClosed {
|
||
|
return errPacketTranslatorConnClosed
|
||
|
}
|
||
|
|
||
|
// Signal any goroutine waiting to read a response.
|
||
|
defer c.responseCond.Broadcast()
|
||
|
|
||
|
// Writes to the buffer should always succeed.
|
||
|
c.responseBuf.Write(packet.Bytes())
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Write is called by the processMessages() loop to send request packets.
|
||
|
func (c *packetTranslatorConn) Write(b []byte) (n int, err error) {
|
||
|
c.lock.Lock()
|
||
|
defer c.lock.Unlock()
|
||
|
|
||
|
if c.isClosed {
|
||
|
return 0, errPacketTranslatorConnClosed
|
||
|
}
|
||
|
|
||
|
// Signal any goroutine waiting to read a request.
|
||
|
defer c.requestCond.Broadcast()
|
||
|
|
||
|
// Writes to the buffer should always succeed.
|
||
|
return c.requestBuf.Write(b)
|
||
|
}
|
||
|
|
||
|
// ReceiveRequest attempts to read a request packet from this connection. It
|
||
|
// will block until it is able to read a full request packet or until this
|
||
|
// connection is closed.
|
||
|
func (c *packetTranslatorConn) ReceiveRequest() (*ber.Packet, error) {
|
||
|
c.lock.Lock()
|
||
|
defer c.lock.Unlock()
|
||
|
|
||
|
for !c.isClosed {
|
||
|
// Attempt to parse a request packet from the request buffer.
|
||
|
// If it fails with an unexpected EOF, wait and try again.
|
||
|
requestReader := bytes.NewReader(c.requestBuf.Bytes())
|
||
|
packet, err := ber.ReadPacket(requestReader)
|
||
|
switch err {
|
||
|
case io.EOF, io.ErrUnexpectedEOF:
|
||
|
c.requestCond.Wait()
|
||
|
case nil:
|
||
|
// Advance the request buffer by the number of bytes
|
||
|
// read to decode the request packet.
|
||
|
c.requestBuf.Next(c.requestBuf.Len() - requestReader.Len())
|
||
|
return packet, nil
|
||
|
default:
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil, errPacketTranslatorConnClosed
|
||
|
}
|
||
|
|
||
|
// Close closes this connection causing Read() and Write() calls to fail.
|
||
|
func (c *packetTranslatorConn) Close() error {
|
||
|
c.lock.Lock()
|
||
|
defer c.lock.Unlock()
|
||
|
|
||
|
c.isClosed = true
|
||
|
c.responseCond.Broadcast()
|
||
|
c.requestCond.Broadcast()
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *packetTranslatorConn) LocalAddr() net.Addr {
|
||
|
return (*net.TCPAddr)(nil)
|
||
|
}
|
||
|
|
||
|
func (c *packetTranslatorConn) RemoteAddr() net.Addr {
|
||
|
return (*net.TCPAddr)(nil)
|
||
|
}
|
||
|
|
||
|
// SetDeadline is part of the net.Conn interface. It is a no-op: the deadline
// is ignored and nil is always returned.
func (c *packetTranslatorConn) SetDeadline(t time.Time) error {
	return nil
}
|
||
|
|
||
|
// SetReadDeadline is part of the net.Conn interface. It is a no-op: the
// deadline is ignored and nil is always returned.
func (c *packetTranslatorConn) SetReadDeadline(t time.Time) error {
	return nil
}
|
||
|
|
||
|
// SetWriteDeadline is part of the net.Conn interface. It is a no-op: the
// deadline is ignored and nil is always returned.
func (c *packetTranslatorConn) SetWriteDeadline(t time.Time) error {
	return nil
}
|