package ldap

import (
	"bytes"
	"errors"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"runtime"
	"sync"
	"testing"
	"time"

	"gopkg.in/asn1-ber.v1"
)

// TestUnresponsiveConnection sends a bind request over a connection whose
// peer never answers with LDAP data and checks that reading the response
// reports the connection timeout error.
func TestUnresponsiveConnection(t *testing.T) {
	// The do-nothing server that accepts requests and does nothing.
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	}))
	defer ts.Close()

	c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("error connecting to localhost tcp: %v", err)
	}

	// Create an LDAP connection with a very short timeout so the test
	// does not have to wait long for the expected failure.
	conn := NewConn(c, false)
	conn.SetTimeout(time.Millisecond)
	conn.Start()
	defer conn.Close()

	// Mock a packet.
	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
	packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, conn.nextMessageID(), "MessageID"))
	bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
	bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
	packet.AppendChild(bindRequest)

	// Send packet and test response.
	msgCtx, err := conn.sendMessage(packet)
	if err != nil {
		t.Fatalf("error sending message: %v", err)
	}
	defer conn.finishMessage(msgCtx)

	packetResponse, ok := <-msgCtx.responses
	if !ok {
		t.Fatalf("no PacketResponse in response channel")
	}

	// Only the error is of interest here; the packet itself is discarded.
	_, err = packetResponse.ReadPacket()
	if err == nil {
		t.Fatalf("expected timeout error")
	}
	if err.Error() != "ldap: connection timed out" {
		t.Fatalf("unexpected error: %v", err)
	}
}

// TestFinishMessage tests that we do not enter deadlock when a goroutine makes
// a request but does not handle all responses from the server.
func TestFinishMessage(t *testing.T) {
	ptc := newPacketTranslatorConn()
	defer ptc.Close()

	conn := NewConn(ptc, false)
	conn.Start()

	// Test sending 5 different requests in series. Ensure that we can
	// get a response packet from the underlying connection and also
	// ensure that we can gracefully ignore unhandled responses.
	for i := 0; i < 5; i++ {
		t.Logf("serial request %d", i)
		// Create a message and make sure we can receive responses.
		msgCtx := testSendRequest(t, ptc, conn)
		testReceiveResponse(t, ptc, msgCtx)

		// Send a few unhandled responses and finish the message.
		testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5)
		t.Logf("serial request %d done", i)
	}

	// Test sending 5 different requests in parallel.
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			t.Logf("parallel request %d", i)
			// Create a message and make sure we can receive responses.
			msgCtx := testSendRequest(t, ptc, conn)
			testReceiveResponse(t, ptc, msgCtx)

			// Send a few unhandled responses and finish the message.
			testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5)
			t.Logf("parallel request %d done", i)
		}(i)
	}
	wg.Wait()

	// We cannot run Close() in a defer because t.FailNow() will run it and
	// it will block if the processMessage Loop is in a deadlock.
	conn.Close()
}

// testSendRequest sends a minimal request packet (a sequence holding only a
// fresh message ID) through conn, then confirms the same request can be read
// back from the server side of ptc. It returns the message context created
// for the request. Every step is bounded by a one second timeout.
func testSendRequest(t *testing.T, ptc *packetTranslatorConn, conn *Conn) (msgCtx *messageContext) {
	var msgID int64
	runWithTimeout(t, time.Second, func() {
		msgID = conn.nextMessageID()
	})

	requestPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
	requestPacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgID, "MessageID"))

	var err error

	runWithTimeout(t, time.Second, func() {
		msgCtx, err = conn.sendMessage(requestPacket)
		if err != nil {
			t.Fatalf("unable to send request message: %s", err)
		}
	})

	// We should now be able to get this request packet out from the other
	// side.
	runWithTimeout(t, time.Second, func() {
		if _, err = ptc.ReceiveRequest(); err != nil {
			t.Fatalf("unable to receive request packet: %s", err)
		}
	})

	return msgCtx
}

// testReceiveResponse injects one mock response carrying msgCtx's message ID
// on the server side of ptc and verifies that it is delivered on
// msgCtx.responses. Every step is bounded by a one second timeout.
func testReceiveResponse(t *testing.T, ptc *packetTranslatorConn, msgCtx *messageContext) {
	// Send a mock response packet.
	responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
	responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID"))

	runWithTimeout(t, time.Second, func() {
		if err := ptc.SendResponse(responsePacket); err != nil {
			t.Fatalf("unable to send response packet: %s", err)
		}
	})

	// We should be able to receive the packet from the connection.
	runWithTimeout(t, time.Second, func() {
		if _, ok := <-msgCtx.responses; !ok {
			t.Fatal("response channel closed")
		}
	})
}

// testSendUnhandledResponsesAndFinish injects numResponses mock responses for
// msgCtx without ever reading them on the client side, then finishes the
// message. Every step is bounded by a one second timeout, so a deadlock shows
// up as a test failure rather than a hang.
func testSendUnhandledResponsesAndFinish(t *testing.T, ptc *packetTranslatorConn, conn *Conn, msgCtx *messageContext, numResponses int) {
	// Build the mock response packet that is sent repeatedly below.
	responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
	responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID"))

	// Send extra responses but do not attempt to receive them on the
	// client side.
	for i := 0; i < numResponses; i++ {
		runWithTimeout(t, time.Second, func() {
			if err := ptc.SendResponse(responsePacket); err != nil {
				t.Fatalf("unable to send response packet: %s", err)
			}
		})
	}

	// Finally, attempt to finish this message.
	runWithTimeout(t, time.Second, func() {
		conn.finishMessage(msgCtx)
	})
}

// runWithTimeout runs f on a new goroutine and aborts the calling test
// goroutine if f does not return normally within timeout.
//
// f may call t.Fatal/t.Fatalf. Those terminate only the helper goroutine (via
// runtime.Goexit), so that case is detected here and re-raised with
// t.FailNow on the caller's goroutine; the failure message logged by f is
// kept and no misleading "timed out" message is added.
func runWithTimeout(t *testing.T, timeout time.Duration, f func()) {
	done := make(chan struct{})
	// returned is written before done is closed and read only after done
	// has been received from, so no extra synchronisation is needed.
	returned := false
	go func() {
		// Deferred so that done is closed even when f unwinds this
		// goroutine instead of returning.
		defer close(done)
		f()
		returned = true
	}()

	// A stoppable timer releases its resources as soon as f finishes.
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case <-done:
		if !returned {
			// f bailed out early; stop the caller as well.
			t.FailNow()
		}
		// Success!
	case <-timer.C:
		_, file, line, _ := runtime.Caller(1)
		t.Fatalf("%s:%d timed out", file, line)
	}
}

// packetTranslatorConn is a helpful type which can be used with various tests
// in this package. It implements the net.Conn interface to be used as an
// underlying connection for a *ldap.Conn. Most methods are no-ops but the
// Read() and Write() methods are able to translate ber-encoded packets for
// testing LDAP requests and responses.
//
// Test cases can simulate an LDAP server sending a response by calling the
// SendResponse() method with a ber-encoded LDAP response packet. Test cases
// can simulate an LDAP server receiving a request from a client by calling the
// ReceiveRequest() method which returns a ber-encoded LDAP request packet.
type packetTranslatorConn struct {
	// lock guards every field below; both condition variables use it.
	lock         sync.Mutex
	isClosed     bool
	responseCond sync.Cond
	requestCond  sync.Cond

	responseBuf bytes.Buffer
	requestBuf  bytes.Buffer
}

// errPacketTranslatorConnClosed is returned by packetTranslatorConn methods
// once the connection has been closed.
var errPacketTranslatorConnClosed = errors.New("connection closed")

// newPacketTranslatorConn returns an open packetTranslatorConn whose
// condition variables are bound to the connection's own mutex.
func newPacketTranslatorConn() *packetTranslatorConn {
	conn := &packetTranslatorConn{}
	conn.responseCond = sync.Cond{L: &conn.lock}
	conn.requestCond = sync.Cond{L: &conn.lock}

	return conn
}

// Read is called by the reader() loop to receive response packets. It will
// block until there are more packet bytes available or this connection is
// closed.
func (c *packetTranslatorConn) Read(b []byte) (n int, err error) {
	c.lock.Lock()
	defer c.lock.Unlock()

	for !c.isClosed {
		// Attempt to read data from the response buffer. If it fails
		// with an EOF, wait and try again.
		n, err = c.responseBuf.Read(b)
		if err != io.EOF {
			return n, err
		}

		c.responseCond.Wait()
	}

	return 0, errPacketTranslatorConnClosed
}

// SendResponse writes the given response packet to the response buffer for
// this connection, signalling any goroutine waiting to read a response.
func (c *packetTranslatorConn) SendResponse(packet *ber.Packet) error {
	c.lock.Lock()
	defer c.lock.Unlock()

	if c.isClosed {
		return errPacketTranslatorConnClosed
	}

	// Signal any goroutine waiting to read a response.
	defer c.responseCond.Broadcast()

	// Writes to the buffer should always succeed.
	c.responseBuf.Write(packet.Bytes())

	return nil
}

// Write is called by the processMessages() loop to send request packets.
func (c *packetTranslatorConn) Write(b []byte) (n int, err error) {
	c.lock.Lock()
	defer c.lock.Unlock()

	if c.isClosed {
		return 0, errPacketTranslatorConnClosed
	}

	// Signal any goroutine waiting to read a request.
	defer c.requestCond.Broadcast()

	// Writes to the buffer should always succeed.
	return c.requestBuf.Write(b)
}

// ReceiveRequest attempts to read a request packet from this connection. It
// will block until it is able to read a full request packet or until this
// connection is closed.
func (c *packetTranslatorConn) ReceiveRequest() (*ber.Packet, error) {
	c.lock.Lock()
	defer c.lock.Unlock()

	for !c.isClosed {
		// Attempt to parse a request packet from the request buffer.
		// If it fails with an unexpected EOF, wait and try again.
		requestReader := bytes.NewReader(c.requestBuf.Bytes())
		packet, err := ber.ReadPacket(requestReader)
		switch err {
		case io.EOF, io.ErrUnexpectedEOF:
			c.requestCond.Wait()
		case nil:
			// Advance the request buffer by the number of bytes
			// read to decode the request packet.
			c.requestBuf.Next(c.requestBuf.Len() - requestReader.Len())
			return packet, nil
		default:
			return nil, err
		}
	}

	return nil, errPacketTranslatorConnClosed
}

// Close closes this connection causing Read() and Write() calls to fail.
func (c *packetTranslatorConn) Close() error {
	c.lock.Lock()
	defer c.lock.Unlock()

	c.isClosed = true
	// Wake every waiter so blocked Read() and ReceiveRequest() calls can
	// observe the closed flag and return.
	c.responseCond.Broadcast()
	c.requestCond.Broadcast()

	return nil
}

// LocalAddr returns a nil *net.TCPAddr; this connection has no real address.
func (c *packetTranslatorConn) LocalAddr() net.Addr {
	return (*net.TCPAddr)(nil)
}

// RemoteAddr returns a nil *net.TCPAddr; this connection has no real address.
func (c *packetTranslatorConn) RemoteAddr() net.Addr {
	return (*net.TCPAddr)(nil)
}

// SetDeadline is a no-op that always returns nil.
func (c *packetTranslatorConn) SetDeadline(t time.Time) error {
	return nil
}

// SetReadDeadline is a no-op that always returns nil.
func (c *packetTranslatorConn) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op that always returns nil.
func (c *packetTranslatorConn) SetWriteDeadline(t time.Time) error {
	return nil
}