Skip to content

Commit 03cc78c

Browse files
authored
fix: request timeout race condition (#465)
1 parent 81e9beb commit 03cc78c

File tree

4 files changed

+78
-16
lines changed

4 files changed

+78
-16
lines changed

conn.go

+13-8
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,10 @@ func (l *Conn) Close() (err error) {
288288
l.chanMessage <- &messagePacket{Op: MessageQuit}
289289

290290
timeoutCtx := context.Background()
291-
if l.requestTimeout > 0 {
291+
requestTimeout := l.getTimeout()
292+
if requestTimeout > 0 {
292293
var cancelFunc context.CancelFunc
293-
timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.requestTimeout))
294+
timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(requestTimeout))
294295
defer cancelFunc()
295296
}
296297
select {
@@ -316,6 +317,10 @@ func (l *Conn) SetTimeout(timeout time.Duration) {
316317
atomic.StoreInt64(&l.requestTimeout, int64(timeout))
317318
}
318319

320+
func (l *Conn) getTimeout() int64 {
321+
return atomic.LoadInt64(&l.requestTimeout)
322+
}
323+
319324
// Returns the next available messageID
320325
func (l *Conn) nextMessageID() int64 {
321326
if messageID, ok := <-l.chanMessageID; ok {
@@ -486,7 +491,7 @@ func (l *Conn) processMessages() {
486491
// If we are closing due to an error, inform anyone who
487492
// is waiting about the error.
488493
if l.IsClosing() && l.closeErr.Load() != nil {
489-
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.requestTimeout))
494+
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.getTimeout()))
490495
}
491496
l.Debug.Printf("Closing channel for MessageID %d", messageID)
492497
close(msgCtx.responses)
@@ -514,7 +519,7 @@ func (l *Conn) processMessages() {
514519
_, err := l.conn.Write(buf)
515520
if err != nil {
516521
l.Debug.Printf("Error Sending Message: %s", err.Error())
517-
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.requestTimeout))
522+
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.getTimeout()))
518523
close(message.Context.responses)
519524
break
520525
}
@@ -524,9 +529,9 @@ func (l *Conn) processMessages() {
524529
l.messageContexts[message.MessageID] = message.Context
525530

526531
// Add timeout if defined
527-
if l.requestTimeout > 0 {
532+
if l.getTimeout() > 0 {
528533
go func() {
529-
timer := time.NewTimer(time.Duration(l.requestTimeout))
534+
timer := time.NewTimer(time.Duration(l.getTimeout()))
530535
defer func() {
531536
if err := recover(); err != nil {
532537
l.err = fmt.Errorf("ldap: recovered panic in RequestTimeout: %v", err)
@@ -549,7 +554,7 @@ func (l *Conn) processMessages() {
549554
case MessageResponse:
550555
l.Debug.Printf("Receiving message %d", message.MessageID)
551556
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
552-
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.requestTimeout))
557+
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.getTimeout()))
553558
} else {
554559
l.err = fmt.Errorf("ldap: received unexpected message %d, %v", message.MessageID, l.IsClosing())
555560
l.Debug.PrintPacket(message.Packet)
@@ -559,7 +564,7 @@ func (l *Conn) processMessages() {
559564
// All reads will return immediately
560565
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
561566
l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
562-
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.requestTimeout))
567+
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.getTimeout()))
563568
delete(l.messageContexts, message.MessageID)
564569
close(msgCtx.responses)
565570
}

conn_test.go

+26
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,32 @@ func TestInvalidStateCloseDeadlock(t *testing.T) {
7979
conn.Close()
8080
}
8181

82+
func TestRequestTimeoutDeadlock(t *testing.T) {
83+
// The do-nothing server that accepts requests and does nothing
84+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
85+
}))
86+
defer ts.Close()
87+
c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
88+
if err != nil {
89+
t.Fatalf("error connecting to localhost tcp: %v", err)
90+
}
91+
92+
// Create an Ldap connection
93+
conn := NewConn(c, false)
94+
conn.Start()
95+
// trigger a race condition on accessing request timeout
96+
n := 3
97+
for i := 0; i < n; i++ {
98+
go func() {
99+
conn.SetTimeout(time.Millisecond)
100+
}()
101+
}
102+
103+
// Attempt to close the connection when the message handler is
104+
// blocked or inactive
105+
conn.Close()
106+
}
107+
82108
// TestInvalidStateSendResponseDeadlock tests that we do not enter deadlock when the
83109
// message handler is blocked or inactive.
84110
func TestInvalidStateSendResponseDeadlock(t *testing.T) {

v3/conn.go

+13-8
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,9 @@ func (l *Conn) Close() (err error) {
288288
l.chanMessage <- &messagePacket{Op: MessageQuit}
289289

290290
timeoutCtx := context.Background()
291-
if l.requestTimeout > 0 {
291+
if l.getTimeout() > 0 {
292292
var cancelFunc context.CancelFunc
293-
timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.requestTimeout))
293+
timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.getTimeout()))
294294
defer cancelFunc()
295295
}
296296
select {
@@ -316,6 +316,10 @@ func (l *Conn) SetTimeout(timeout time.Duration) {
316316
atomic.StoreInt64(&l.requestTimeout, int64(timeout))
317317
}
318318

319+
func (l *Conn) getTimeout() int64 {
320+
return atomic.LoadInt64(&l.requestTimeout)
321+
}
322+
319323
// Returns the next available messageID
320324
func (l *Conn) nextMessageID() int64 {
321325
if messageID, ok := <-l.chanMessageID; ok {
@@ -486,7 +490,7 @@ func (l *Conn) processMessages() {
486490
// If we are closing due to an error, inform anyone who
487491
// is waiting about the error.
488492
if l.IsClosing() && l.closeErr.Load() != nil {
489-
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.requestTimeout))
493+
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.getTimeout()))
490494
}
491495
l.Debug.Printf("Closing channel for MessageID %d", messageID)
492496
close(msgCtx.responses)
@@ -514,7 +518,7 @@ func (l *Conn) processMessages() {
514518
_, err := l.conn.Write(buf)
515519
if err != nil {
516520
l.Debug.Printf("Error Sending Message: %s", err.Error())
517-
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.requestTimeout))
521+
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.getTimeout()))
518522
close(message.Context.responses)
519523
break
520524
}
@@ -524,9 +528,10 @@ func (l *Conn) processMessages() {
524528
l.messageContexts[message.MessageID] = message.Context
525529

526530
// Add timeout if defined
527-
if l.requestTimeout > 0 {
531+
requestTimeout := l.getTimeout()
532+
if requestTimeout > 0 {
528533
go func() {
529-
timer := time.NewTimer(time.Duration(l.requestTimeout))
534+
timer := time.NewTimer(time.Duration(requestTimeout))
530535
defer func() {
531536
if err := recover(); err != nil {
532537
l.err = fmt.Errorf("ldap: recovered panic in RequestTimeout: %v", err)
@@ -549,7 +554,7 @@ func (l *Conn) processMessages() {
549554
case MessageResponse:
550555
l.Debug.Printf("Receiving message %d", message.MessageID)
551556
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
552-
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.requestTimeout))
557+
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.getTimeout()))
553558
} else {
554559
l.err = fmt.Errorf("ldap: received unexpected message %d, %v", message.MessageID, l.IsClosing())
555560
l.Debug.PrintPacket(message.Packet)
@@ -559,7 +564,7 @@ func (l *Conn) processMessages() {
559564
// All reads will return immediately
560565
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
561566
l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
562-
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.requestTimeout))
567+
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.getTimeout()))
563568
delete(l.messageContexts, message.MessageID)
564569
close(msgCtx.responses)
565570
}

v3/conn_test.go

+26
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,32 @@ func TestUnresponsiveConnection(t *testing.T) {
5858
}
5959
}
6060

61+
func TestRequestTimeoutDeadlock(t *testing.T) {
62+
// The do-nothing server that accepts requests and does nothing
63+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
64+
}))
65+
defer ts.Close()
66+
c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
67+
if err != nil {
68+
t.Fatalf("error connecting to localhost tcp: %v", err)
69+
}
70+
71+
// Create an Ldap connection
72+
conn := NewConn(c, false)
73+
conn.Start()
74+
// trigger a race condition on accessing request timeout
75+
n := 3
76+
for i := 0; i < n; i++ {
77+
go func() {
78+
conn.SetTimeout(time.Millisecond)
79+
}()
80+
}
81+
82+
// Attempt to close the connection when the message handler is
83+
// blocked or inactive
84+
conn.Close()
85+
}
86+
6187
// TestInvalidStateCloseDeadlock tests that we do not enter deadlock when the
6288
// message handler is blocked or inactive.
6389
func TestInvalidStateCloseDeadlock(t *testing.T) {

0 commit comments

Comments
 (0)