Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix deadlocks caused by invalid connection state #432

Merged
merged 6 commits into from
Apr 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ldap

import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
Expand Down Expand Up @@ -60,13 +61,21 @@ type messageContext struct {

// sendResponse should only be called within the processMessages() loop which
// is also responsible for closing the responses channel.
func (msgCtx *messageContext) sendResponse(packet *PacketResponse) {
func (msgCtx *messageContext) sendResponse(packet *PacketResponse, timeout time.Duration) {
timeoutCtx := context.Background()
if timeout > 0 {
var cancelFunc context.CancelFunc
timeoutCtx, cancelFunc = context.WithTimeout(context.Background(), timeout)
defer cancelFunc()
}
select {
case msgCtx.responses <- packet:
// Successfully sent packet to message handler.
case <-msgCtx.done:
// The request handler is done and will not receive more
// packets.
case <-timeoutCtx.Done():
// The timeout was reached before the packet was sent.
}
}

Expand Down Expand Up @@ -238,7 +247,7 @@ func DialURL(addr string, opts ...DialOpt) (*Conn, error) {

// NewConn returns a new Conn using conn for network I/O.
func NewConn(conn net.Conn, isTLS bool) *Conn {
return &Conn{
l := &Conn{
conn: conn,
chanConfirm: make(chan struct{}),
chanMessageID: make(chan int64),
Expand All @@ -247,11 +256,12 @@ func NewConn(conn net.Conn, isTLS bool) *Conn {
requestTimeout: 0,
isTLS: isTLS,
}
l.wgClose.Add(1)
return l
}

// Start initializes goroutines to read responses and process messages
func (l *Conn) Start() {
l.wgClose.Add(1)
go l.reader()
go l.processMessages()
}
Expand All @@ -274,7 +284,20 @@ func (l *Conn) Close() {
if l.setClosing() {
l.Debug.Printf("Sending quit message and waiting for confirmation")
l.chanMessage <- &messagePacket{Op: MessageQuit}
<-l.chanConfirm

timeoutCtx := context.Background()
if l.requestTimeout > 0 {
var cancelFunc context.CancelFunc
timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.requestTimeout))
defer cancelFunc()
}
select {
case <-l.chanConfirm:
// Confirmation was received.
case <-timeoutCtx.Done():
// The timeout was reached before confirmation was received.
}

close(l.chanMessage)

l.Debug.Printf("Closing network connection")
Expand Down Expand Up @@ -454,7 +477,7 @@ func (l *Conn) processMessages() {
// If we are closing due to an error, inform anyone who
// is waiting about the error.
if l.IsClosing() && l.closeErr.Load() != nil {
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)})
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.requestTimeout))
}
l.Debug.Printf("Closing channel for MessageID %d", messageID)
close(msgCtx.responses)
Expand Down Expand Up @@ -482,7 +505,7 @@ func (l *Conn) processMessages() {
_, err := l.conn.Write(buf)
if err != nil {
l.Debug.Printf("Error Sending Message: %s", err.Error())
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)})
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.requestTimeout))
close(message.Context.responses)
break
}
Expand Down Expand Up @@ -517,7 +540,7 @@ func (l *Conn) processMessages() {
case MessageResponse:
l.Debug.Printf("Receiving message %d", message.MessageID)
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.requestTimeout))
} else {
logger.Printf("Received unexpected message %d, %v", message.MessageID, l.IsClosing())
l.Debug.PrintPacket(message.Packet)
Expand All @@ -527,7 +550,7 @@ func (l *Conn) processMessages() {
// All reads will return immediately
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))})
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.requestTimeout))
delete(l.messageContexts, message.MessageID)
close(msgCtx.responses)
}
Expand Down
33 changes: 33 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,39 @@ func TestUnresponsiveConnection(t *testing.T) {
}
}

// TestInvalidStateCloseDeadlock tests that we do not enter deadlock when the
// message handler is blocked or inactive.
func TestInvalidStateCloseDeadlock(t *testing.T) {
// The do-nothing server that accepts requests and does nothing
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer ts.Close()
c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
if err != nil {
t.Fatalf("error connecting to localhost tcp: %v", err)
}

// Create an Ldap connection
conn := NewConn(c, false)
conn.SetTimeout(time.Millisecond)

// Attempt to close the connection when the message handler is
// blocked or inactive
conn.Close()
}

// TestInvalidStateSendResponseDeadlock tests that we do not enter deadlock when the
// message handler is blocked or inactive.
func TestInvalidStateSendResponseDeadlock(t *testing.T) {
// Attempt to send a response packet when the message handler is blocked or inactive
msgCtx := &messageContext{
id: 0,
done: make(chan struct{}),
responses: make(chan *PacketResponse),
}
msgCtx.sendResponse(&PacketResponse{}, time.Millisecond)
}

// TestFinishMessage tests that we do not enter deadlock when a goroutine makes
// a request but does not handle all responses from the server.
func TestFinishMessage(t *testing.T) {
Expand Down
39 changes: 31 additions & 8 deletions v3/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ldap

import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
Expand Down Expand Up @@ -60,13 +61,21 @@ type messageContext struct {

// sendResponse should only be called within the processMessages() loop which
// is also responsible for closing the responses channel.
func (msgCtx *messageContext) sendResponse(packet *PacketResponse) {
func (msgCtx *messageContext) sendResponse(packet *PacketResponse, timeout time.Duration) {
timeoutCtx := context.Background()
if timeout > 0 {
var cancelFunc context.CancelFunc
timeoutCtx, cancelFunc = context.WithTimeout(context.Background(), timeout)
defer cancelFunc()
}
select {
case msgCtx.responses <- packet:
// Successfully sent packet to message handler.
case <-msgCtx.done:
// The request handler is done and will not receive more
// packets.
case <-timeoutCtx.Done():
// The timeout was reached before the packet was sent.
}
}

Expand Down Expand Up @@ -238,7 +247,7 @@ func DialURL(addr string, opts ...DialOpt) (*Conn, error) {

// NewConn returns a new Conn using conn for network I/O.
func NewConn(conn net.Conn, isTLS bool) *Conn {
return &Conn{
l := &Conn{
conn: conn,
chanConfirm: make(chan struct{}),
chanMessageID: make(chan int64),
Expand All @@ -247,11 +256,12 @@ func NewConn(conn net.Conn, isTLS bool) *Conn {
requestTimeout: 0,
isTLS: isTLS,
}
l.wgClose.Add(1)
return l
}

// Start initializes goroutines to read responses and process messages
func (l *Conn) Start() {
l.wgClose.Add(1)
go l.reader()
go l.processMessages()
}
Expand All @@ -274,7 +284,20 @@ func (l *Conn) Close() {
if l.setClosing() {
l.Debug.Printf("Sending quit message and waiting for confirmation")
l.chanMessage <- &messagePacket{Op: MessageQuit}
<-l.chanConfirm

timeoutCtx := context.Background()
if l.requestTimeout > 0 {
var cancelFunc context.CancelFunc
timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.requestTimeout))
defer cancelFunc()
}
select {
case <-l.chanConfirm:
// Confirmation was received.
case <-timeoutCtx.Done():
// The timeout was reached before confirmation was received.
}

close(l.chanMessage)

l.Debug.Printf("Closing network connection")
Expand Down Expand Up @@ -454,7 +477,7 @@ func (l *Conn) processMessages() {
// If we are closing due to an error, inform anyone who
// is waiting about the error.
if l.IsClosing() && l.closeErr.Load() != nil {
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)})
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.requestTimeout))
}
l.Debug.Printf("Closing channel for MessageID %d", messageID)
close(msgCtx.responses)
Expand Down Expand Up @@ -482,7 +505,7 @@ func (l *Conn) processMessages() {
_, err := l.conn.Write(buf)
if err != nil {
l.Debug.Printf("Error Sending Message: %s", err.Error())
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)})
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.requestTimeout))
close(message.Context.responses)
break
}
Expand Down Expand Up @@ -517,7 +540,7 @@ func (l *Conn) processMessages() {
case MessageResponse:
l.Debug.Printf("Receiving message %d", message.MessageID)
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.requestTimeout))
} else {
logger.Printf("Received unexpected message %d, %v", message.MessageID, l.IsClosing())
l.Debug.PrintPacket(message.Packet)
Expand All @@ -527,7 +550,7 @@ func (l *Conn) processMessages() {
// All reads will return immediately
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))})
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.requestTimeout))
delete(l.messageContexts, message.MessageID)
close(msgCtx.responses)
}
Expand Down
33 changes: 33 additions & 0 deletions v3/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,39 @@ func TestUnresponsiveConnection(t *testing.T) {
}
}

// TestInvalidStateCloseDeadlock tests that we do not enter deadlock when the
// message handler is blocked or inactive.
func TestInvalidStateCloseDeadlock(t *testing.T) {
// The do-nothing server that accepts requests and does nothing
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer ts.Close()
c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
if err != nil {
t.Fatalf("error connecting to localhost tcp: %v", err)
}

// Create an Ldap connection
conn := NewConn(c, false)
conn.SetTimeout(time.Millisecond)

// Attempt to close the connection when the message handler is
// blocked or inactive
conn.Close()
}

// TestInvalidStateSendResponseDeadlock tests that we do not enter deadlock when the
// message handler is blocked or inactive.
func TestInvalidStateSendResponseDeadlock(t *testing.T) {
// Attempt to send a response packet when the message handler is blocked or inactive
msgCtx := &messageContext{
id: 0,
done: make(chan struct{}),
responses: make(chan *PacketResponse),
}
msgCtx.sendResponse(&PacketResponse{}, time.Millisecond)
}

// TestFinishMessage tests that we do not enter deadlock when a goroutine makes
// a request but does not handle all responses from the server.
func TestFinishMessage(t *testing.T) {
Expand Down