@@ -288,9 +288,10 @@ func (l *Conn) Close() (err error) {
288
288
l .chanMessage <- & messagePacket {Op : MessageQuit }
289
289
290
290
timeoutCtx := context .Background ()
291
- if l .requestTimeout > 0 {
291
+ requestTimeout := l .getTimeout ()
292
+ if requestTimeout > 0 {
292
293
var cancelFunc context.CancelFunc
293
- timeoutCtx , cancelFunc = context .WithTimeout (timeoutCtx , time .Duration (l . requestTimeout ))
294
+ timeoutCtx , cancelFunc = context .WithTimeout (timeoutCtx , time .Duration (requestTimeout ))
294
295
defer cancelFunc ()
295
296
}
296
297
select {
@@ -316,6 +317,10 @@ func (l *Conn) SetTimeout(timeout time.Duration) {
316
317
atomic .StoreInt64 (& l .requestTimeout , int64 (timeout ))
317
318
}
318
319
320
+ func (l * Conn ) getTimeout () int64 {
321
+ return atomic .LoadInt64 (& l .requestTimeout )
322
+ }
323
+
319
324
// Returns the next available messageID
320
325
func (l * Conn ) nextMessageID () int64 {
321
326
if messageID , ok := <- l .chanMessageID ; ok {
@@ -486,7 +491,7 @@ func (l *Conn) processMessages() {
486
491
// If we are closing due to an error, inform anyone who
487
492
// is waiting about the error.
488
493
if l .IsClosing () && l .closeErr .Load () != nil {
489
- msgCtx .sendResponse (& PacketResponse {Error : l .closeErr .Load ().(error )}, time .Duration (l .requestTimeout ))
494
+ msgCtx .sendResponse (& PacketResponse {Error : l .closeErr .Load ().(error )}, time .Duration (l .getTimeout () ))
490
495
}
491
496
l .Debug .Printf ("Closing channel for MessageID %d" , messageID )
492
497
close (msgCtx .responses )
@@ -514,7 +519,7 @@ func (l *Conn) processMessages() {
514
519
_ , err := l .conn .Write (buf )
515
520
if err != nil {
516
521
l .Debug .Printf ("Error Sending Message: %s" , err .Error ())
517
- message .Context .sendResponse (& PacketResponse {Error : fmt .Errorf ("unable to send request: %s" , err )}, time .Duration (l .requestTimeout ))
522
+ message .Context .sendResponse (& PacketResponse {Error : fmt .Errorf ("unable to send request: %s" , err )}, time .Duration (l .getTimeout () ))
518
523
close (message .Context .responses )
519
524
break
520
525
}
@@ -524,9 +529,9 @@ func (l *Conn) processMessages() {
524
529
l .messageContexts [message .MessageID ] = message .Context
525
530
526
531
// Add timeout if defined
527
- if l .requestTimeout > 0 {
532
+ if l .getTimeout () > 0 {
528
533
go func () {
529
- timer := time .NewTimer (time .Duration (l .requestTimeout ))
534
+ timer := time .NewTimer (time .Duration (l .getTimeout () ))
530
535
defer func () {
531
536
if err := recover (); err != nil {
532
537
l .err = fmt .Errorf ("ldap: recovered panic in RequestTimeout: %v" , err )
@@ -549,7 +554,7 @@ func (l *Conn) processMessages() {
549
554
case MessageResponse :
550
555
l .Debug .Printf ("Receiving message %d" , message .MessageID )
551
556
if msgCtx , ok := l .messageContexts [message .MessageID ]; ok {
552
- msgCtx .sendResponse (& PacketResponse {message .Packet , nil }, time .Duration (l .requestTimeout ))
557
+ msgCtx .sendResponse (& PacketResponse {message .Packet , nil }, time .Duration (l .getTimeout () ))
553
558
} else {
554
559
l .err = fmt .Errorf ("ldap: received unexpected message %d, %v" , message .MessageID , l .IsClosing ())
555
560
l .Debug .PrintPacket (message .Packet )
@@ -559,7 +564,7 @@ func (l *Conn) processMessages() {
559
564
// All reads will return immediately
560
565
if msgCtx , ok := l .messageContexts [message .MessageID ]; ok {
561
566
l .Debug .Printf ("Receiving message timeout for %d" , message .MessageID )
562
- msgCtx .sendResponse (& PacketResponse {message .Packet , NewError (ErrorNetwork , errors .New ("ldap: connection timed out" ))}, time .Duration (l .requestTimeout ))
567
+ msgCtx .sendResponse (& PacketResponse {message .Packet , NewError (ErrorNetwork , errors .New ("ldap: connection timed out" ))}, time .Duration (l .getTimeout () ))
563
568
delete (l .messageContexts , message .MessageID )
564
569
close (msgCtx .responses )
565
570
}
0 commit comments