From a16b08d372d7a93d4851b86e84e521403d0f6a4a Mon Sep 17 00:00:00 2001
From: sukun
Date: Fri, 3 May 2024 21:29:53 +0530
Subject: [PATCH] Always send 1 event for a connection

---
 p2p/net/swarm/connectedness_event_emitter.go | 143 +++++++++++++++++++
 p2p/net/swarm/dial_worker.go                 |   2 +-
 p2p/net/swarm/swarm.go                       | 109 +++++---------
 p2p/net/swarm/swarm_conn.go                  |   9 +-
 p2p/net/swarm/swarm_event_test.go            |  66 +++++++++
 p2p/net/swarm/swarm_listen.go                |   3 +-
 6 files changed, 253 insertions(+), 79 deletions(-)
 create mode 100644 p2p/net/swarm/connectedness_event_emitter.go

diff --git a/p2p/net/swarm/connectedness_event_emitter.go b/p2p/net/swarm/connectedness_event_emitter.go
new file mode 100644
index 0000000000..793134136b
--- /dev/null
+++ b/p2p/net/swarm/connectedness_event_emitter.go
@@ -0,0 +1,143 @@
+package swarm
+
+import (
+	"context"
+	"sync"
+
+	"github.com/libp2p/go-libp2p/core/event"
+	"github.com/libp2p/go-libp2p/core/network"
+	"github.com/libp2p/go-libp2p/core/peer"
+)
+
+// connectednessEventEmitter emits PeerConnectednessChanged events.
+// We ensure that, for any peer we connected to, we always send at least one NotConnected
+// event after the peer disconnects. This is because subscribers can observe a connection
+// before they are notified of it by a PeerConnectednessChanged event.
+type connectednessEventEmitter struct {
+	mx sync.RWMutex
+	// newConns is the channel that holds the peerIDs we recently connected to
+	newConns chan peer.ID
+	removeConnsMx sync.Mutex
+	// removeConns is a slice of peerIDs we have recently closed connections to
+	removeConns []peer.ID
+	// lastEvent is the last connectedness event sent for a particular peer.
+	lastEvent map[peer.ID]network.Connectedness
+	// connectedness is the function that gives the peer's current connectedness state
+	connectedness func(peer.ID) network.Connectedness
+	// emitter is the PeerConnectednessChanged event emitter
+	emitter event.Emitter
+	wg sync.WaitGroup
+	removeConnNotif chan struct{}
+	ctx context.Context
+	cancel context.CancelFunc
+}
+
+func newConnectednessEventEmitter(connectedness func(peer.ID) network.Connectedness, emitter event.Emitter) *connectednessEventEmitter {
+	ctx, cancel := context.WithCancel(context.Background())
+	c := &connectednessEventEmitter{
+		newConns: make(chan peer.ID, 32),
+		lastEvent: make(map[peer.ID]network.Connectedness),
+		removeConnNotif: make(chan struct{}, 1),
+		connectedness: connectedness,
+		emitter: emitter,
+		ctx: ctx,
+		cancel: cancel,
+	}
+	c.wg.Add(1)
+	go c.runEmitter()
+	return c
+}
+
+func (c *connectednessEventEmitter) AddConn(p peer.ID) {
+	c.mx.RLock()
+	defer c.mx.RUnlock()
+	if c.ctx.Err() != nil {
+		return
+	}
+
+	c.newConns <- p
+}
+
+func (c *connectednessEventEmitter) RemoveConn(p peer.ID) {
+	c.mx.RLock()
+	defer c.mx.RUnlock()
+	if c.ctx.Err() != nil {
+		return
+	}
+
+	c.removeConnsMx.Lock()
+	// This slice is effectively bounded because AddConn blocks on the newConns
+	// channel: new connections are added to the swarm only as fast as the
+	// subscriber of our PeerConnectednessChanged events can consume them.
+	// If many open connections are closed at once, increasing the rate of
+	// disconnected notifications, the rate of adding new connections to the
+	// swarm drops proportionately, which eventually shrinks this slice again.
+ c.removeConns = append(c.removeConns, p) + c.removeConnsMx.Unlock() + + select { + case c.removeConnNotif <- struct{}{}: + default: + } +} + +func (c *connectednessEventEmitter) Close() { + c.cancel() + c.wg.Wait() +} + +func (c *connectednessEventEmitter) runEmitter() { + defer c.wg.Done() + for { + select { + case p := <-c.newConns: + c.notifyPeer(p, true) + case <-c.removeConnNotif: + c.sendConnRemovedNotifications() + case <-c.ctx.Done(): + c.mx.Lock() // Wait for all pending AddConn & RemoveConn operations to complete + defer c.mx.Unlock() + for { + select { + case p := <-c.newConns: + c.notifyPeer(p, true) + case <-c.removeConnNotif: + c.sendConnRemovedNotifications() + default: + return + } + } + } + } +} + +func (c *connectednessEventEmitter) notifyPeer(p peer.ID, forceNotConnectedEvent bool) { + oldState := c.lastEvent[p] + c.lastEvent[p] = c.connectedness(p) + if c.lastEvent[p] == network.NotConnected { + delete(c.lastEvent, p) + } + if (forceNotConnectedEvent && c.lastEvent[p] == network.NotConnected) || c.lastEvent[p] != oldState { + c.emitter.Emit(event.EvtPeerConnectednessChanged{ + Peer: p, + Connectedness: c.lastEvent[p], + }) + } +} + +func (c *connectednessEventEmitter) sendConnRemovedNotifications() { + c.removeConnsMx.Lock() + defer c.removeConnsMx.Unlock() + for { + if len(c.removeConns) == 0 { + return + } + p := c.removeConns[0] + c.removeConns[0] = "" + c.removeConns = c.removeConns[1:] + + c.removeConnsMx.Unlock() + c.notifyPeer(p, false) + c.removeConnsMx.Lock() + } +} diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index 0cac6e4fa3..2ebc4e1efd 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -340,7 +340,7 @@ loop: ad.expectedTCPUpgradeTime = time.Time{} if res.Conn != nil { // we got a connection, add it to the swarm - conn, err := w.s.addConn(res.Conn, network.DirOutbound) + conn, err := w.s.addConn(ad.ctx, res.Conn, network.DirOutbound) if err != nil { // oops no, we failed to add it to the swarm res.Conn.Close() diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 3242bf3076..137fc39510 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -144,9 +144,7 @@ type Swarm struct { // down before continuing. refs sync.WaitGroup - emitter event.Emitter - connectednessEventCh chan struct{} - connectednessEmitterDone chan struct{} + emitter event.Emitter rcmgr network.ResourceManager @@ -158,8 +156,7 @@ type Swarm struct { conns struct { sync.RWMutex - m map[peer.ID][]*Conn - connectednessEvents chan peer.ID + m map[peer.ID][]*Conn } listeners struct { @@ -206,9 +203,10 @@ type Swarm struct { dialRanker network.DialRanker - udpBlackHoleConfig blackHoleConfig - ipv6BlackHoleConfig blackHoleConfig - bhd *blackHoleDetector + udpBlackHoleConfig blackHoleConfig + ipv6BlackHoleConfig blackHoleConfig + bhd *blackHoleDetector + connectednessEventEmitter *connectednessEventEmitter } // NewSwarm constructs a Swarm. 
@@ -219,17 +217,15 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts } ctx, cancel := context.WithCancel(context.Background()) s := &Swarm{ - local: local, - peers: peers, - emitter: emitter, - connectednessEventCh: make(chan struct{}, 1), - connectednessEmitterDone: make(chan struct{}), - ctx: ctx, - ctxCancel: cancel, - dialTimeout: defaultDialTimeout, - dialTimeoutLocal: defaultDialTimeoutLocal, - maResolver: madns.DefaultResolver, - dialRanker: DefaultDialRanker, + local: local, + peers: peers, + emitter: emitter, + ctx: ctx, + ctxCancel: cancel, + dialTimeout: defaultDialTimeout, + dialTimeoutLocal: defaultDialTimeoutLocal, + maResolver: madns.DefaultResolver, + dialRanker: DefaultDialRanker, // A black hole is a binary property. On a network if UDP dials are blocked or there is // no IPv6 connectivity, all dials will fail. So a low success rate of 5 out 100 dials @@ -239,11 +235,11 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts } s.conns.m = make(map[peer.ID][]*Conn) - s.conns.connectednessEvents = make(chan peer.ID, 32) s.listeners.m = make(map[transport.Listener]struct{}) s.transports.m = make(map[int]transport.Transport) s.notifs.m = make(map[network.Notifiee]struct{}) s.directConnNotifs.m = make(map[peer.ID][]chan struct{}) + s.connectednessEventEmitter = newConnectednessEventEmitter(s.Connectedness, emitter) for _, opt := range opts { if err := opt(s); err != nil { @@ -260,7 +256,6 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts s.backf.init(s.ctx) s.bhd = newBlackHoleDetector(s.udpBlackHoleConfig, s.ipv6BlackHoleConfig, s.metricsTracer) - go s.connectednessEventEmitter() return s, nil } @@ -306,8 +301,7 @@ func (s *Swarm) close() { // Wait for everything to finish. s.refs.Wait() - close(s.conns.connectednessEvents) - <-s.connectednessEmitterDone + s.connectednessEventEmitter.Close() s.emitter.Close() // Now close out any transports (if necessary). Do this after closing @@ -338,7 +332,7 @@ func (s *Swarm) close() { wg.Wait() } -func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, error) { +func (s *Swarm) addConn(ctx context.Context, tc transport.CapableConn, dir network.Direction) (*Conn, error) { var ( p = tc.RemotePeer() addr = tc.RemoteMultiaddr() @@ -397,18 +391,15 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, // * One will be decremented after the close notifications fire in Conn.doClose // * The other will be decremented when Conn.start exits. s.refs.Add(2) - // Take the notification lock before releasing the conns lock to block // Disconnect notifications until after the Connect notifications done. + // This lock also ensures that swarm.refs.Wait() exits after we have + // enqueued the peer connectedness changed notification. + // TODO: Fix this fragility by taking a swarm ref for dial worker loop c.notifyLk.Lock() s.conns.Unlock() - // Block this goroutine till this request is enqueued. - // This ensures that there are only a finite number of goroutines that are waiting to send - // the connectedness event on the disconnection side in swarm.removeConn. 
- // This is so because the goroutine to enqueue disconnection event can only be started - // from either a subscriber or a notifier or after calling c.start - s.conns.connectednessEvents <- p + s.connectednessEventEmitter.AddConn(p) if !isLimited { // Notify goroutines waiting for a direct connection @@ -423,7 +414,6 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, delete(s.directConnNotifs.m, p) s.directConnNotifs.Unlock() } - s.notifyAll(func(f network.Notifiee) { f.Connected(s, c) }) @@ -771,52 +761,21 @@ func (s *Swarm) removeConn(c *Conn) { s.conns.Lock() cs := s.conns.m[p] - if len(cs) == 1 { - delete(s.conns.m, p) - } else { - for i, ci := range cs { - if ci == c { - // NOTE: We're intentionally preserving order. - // This way, connections to a peer are always - // sorted oldest to newest. - copy(cs[i:], cs[i+1:]) - cs[len(cs)-1] = nil - s.conns.m[p] = cs[:len(cs)-1] - break - } + for i, ci := range cs { + if ci == c { + // NOTE: We're intentionally preserving order. + // This way, connections to a peer are always + // sorted oldest to newest. + copy(cs[i:], cs[i+1:]) + cs[len(cs)-1] = nil + s.conns.m[p] = cs[:len(cs)-1] + break } } - s.conns.Unlock() - // Do this in a separate go routine to not block the caller. - // This ensures that if a event subscriber closes the connection from the subscription goroutine - // this doesn't deadlock - s.refs.Add(1) - go func() { - defer s.refs.Done() - s.conns.connectednessEvents <- p - }() -} - -func (s *Swarm) connectednessEventEmitter() { - defer close(s.connectednessEmitterDone) - lastConnectednessEvents := make(map[peer.ID]network.Connectedness) - for p := range s.conns.connectednessEvents { - s.conns.Lock() - oldState := lastConnectednessEvents[p] - newState := s.connectednessUnlocked(p) - if newState != network.NotConnected { - lastConnectednessEvents[p] = newState - } else { - delete(lastConnectednessEvents, p) - } - s.conns.Unlock() - if newState != oldState { - s.emitter.Emit(event.EvtPeerConnectednessChanged{ - Peer: p, - Connectedness: newState, - }) - } + if len(s.conns.m[p]) == 0 { + delete(s.conns.m, p) } + s.conns.Unlock() } // String returns a string representation of Network. diff --git a/p2p/net/swarm/swarm_conn.go b/p2p/net/swarm/swarm_conn.go index 17ae1dffae..38e942cce8 100644 --- a/p2p/net/swarm/swarm_conn.go +++ b/p2p/net/swarm/swarm_conn.go @@ -73,6 +73,11 @@ func (c *Conn) doClose() { c.err = c.conn.Close() + // Send the connectedness event after closing the connection. + // This ensures that both remote connection close and local connection + // close events are sent after the underlying transport connection is closed. + c.swarm.connectednessEventEmitter.RemoveConn(c.RemotePeer()) + // This is just for cleaning up state. The connection has already been closed. // We *could* optimize this but it really isn't worth it. 
for s := range streams { @@ -85,10 +90,11 @@ func (c *Conn) doClose() { c.notifyLk.Lock() defer c.notifyLk.Unlock() + // Only notify for disconnection if we notified for connection c.swarm.notifyAll(func(f network.Notifiee) { f.Disconnected(c.swarm, c) }) - c.swarm.refs.Done() // taken in Swarm.addConn + c.swarm.refs.Done() }() } @@ -108,7 +114,6 @@ func (c *Conn) start() { go func() { defer c.swarm.refs.Done() defer c.Close() - for { ts, err := c.conn.AcceptStream() if err != nil { diff --git a/p2p/net/swarm/swarm_event_test.go b/p2p/net/swarm/swarm_event_test.go index 12393c5b1d..0a06a98fe4 100644 --- a/p2p/net/swarm/swarm_event_test.go +++ b/p2p/net/swarm/swarm_event_test.go @@ -2,6 +2,8 @@ package swarm_test import ( "context" + "fmt" + "sync" "testing" "time" @@ -244,3 +246,67 @@ func TestConnectednessEventDeadlock(t *testing.T) { t.Fatal("expected all connectedness events to be completed") } } + +func TestConnectednessEventDeadlockWithDial(t *testing.T) { + s1, sub1 := newSwarmWithSubscription(t) + const N = 200 + peers := make([]*Swarm, N) + for i := 0; i < N; i++ { + peers[i] = swarmt.GenSwarm(t) + } + peers2 := make([]*Swarm, N) + for i := 0; i < N; i++ { + peers2[i] = swarmt.GenSwarm(t) + } + + // First check all connected events + done := make(chan struct{}) + var subWG sync.WaitGroup + subWG.Add(1) + go func() { + defer subWG.Done() + count := 0 + for { + var e interface{} + select { + case e = <-sub1.Out(): + case <-done: + return + } + // sleep to simulate a slow consumer + evt, ok := e.(event.EvtPeerConnectednessChanged) + if !ok { + t.Error("invalid event received", e) + return + } + if evt.Connectedness != network.Connected { + continue + } + if count < N { + time.Sleep(10 * time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + s1.Peerstore().AddAddrs(peers2[count].LocalPeer(), []ma.Multiaddr{peers2[count].ListenAddresses()[0]}, time.Hour) + s1.DialPeer(ctx, peers2[count].LocalPeer()) + count++ + cancel() + } + } + }() + var wg sync.WaitGroup + wg.Add(N) + for i := 0; i < N; i++ { + s1.Peerstore().AddAddrs(peers[i].LocalPeer(), []ma.Multiaddr{peers[i].ListenAddresses()[0]}, time.Hour) + go func(i int) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + s1.DialPeer(ctx, peers[i].LocalPeer()) + cancel() + wg.Done() + }(i) + } + wg.Wait() + s1.Close() + + close(done) + subWG.Wait() + fmt.Println("swarm closed") +} diff --git a/p2p/net/swarm/swarm_listen.go b/p2p/net/swarm/swarm_listen.go index 0905e84513..2376a7e379 100644 --- a/p2p/net/swarm/swarm_listen.go +++ b/p2p/net/swarm/swarm_listen.go @@ -1,6 +1,7 @@ package swarm import ( + "context" "errors" "fmt" "time" @@ -142,7 +143,7 @@ func (s *Swarm) AddListenAddr(a ma.Multiaddr) error { s.refs.Add(1) go func() { defer s.refs.Done() - _, err := s.addConn(c, network.DirInbound) + _, err := s.addConn(context.Background(), c, network.DirInbound) switch err { case nil: case ErrSwarmClosed:
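
To illustrate what the "at least one NotConnected event" guarantee buys consumers, below is a minimal subscriber sketch. It is not part of this patch and assumes only the public go-libp2p APIs (libp2p.New, Host.EventBus, event.EvtPeerConnectednessChanged, network.Connectedness); the connectedPeers map and the surrounding wiring are illustrative. Because the swarm now always emits a trailing NotConnected event for every peer it connected to, the NotConnected branch is a reliable place to release per-peer state, even when the subscriber first learned about the connection some other way (for example through a Notifiee callback or an inbound stream).

package main

import (
	"fmt"
	"sync"

	"github.com/libp2p/go-libp2p"
	"github.com/libp2p/go-libp2p/core/event"
	"github.com/libp2p/go-libp2p/core/network"
	"github.com/libp2p/go-libp2p/core/peer"
)

func main() {
	h, err := libp2p.New()
	if err != nil {
		panic(err)
	}
	defer h.Close()

	// Subscribe to connectedness changes on the host's event bus.
	sub, err := h.EventBus().Subscribe(new(event.EvtPeerConnectednessChanged))
	if err != nil {
		panic(err)
	}
	defer sub.Close()

	// connectedPeers is illustrative per-peer state keyed by peer ID.
	var mu sync.Mutex
	connectedPeers := make(map[peer.ID]struct{})

	go func() {
		for e := range sub.Out() {
			evt := e.(event.EvtPeerConnectednessChanged)
			mu.Lock()
			switch evt.Connectedness {
			case network.Connected:
				connectedPeers[evt.Peer] = struct{}{}
			case network.NotConnected:
				// The swarm emits at least one NotConnected event for every peer
				// it connected to, so this branch can safely release per-peer
				// state without leaking entries.
				delete(connectedPeers, evt.Peer)
			}
			mu.Unlock()
		}
	}()

	fmt.Println("tracking connectedness events for", h.ID())
	// ... dial peers, serve protocols, etc.; closing the host ends the stream of events.
}

The emitter added in this patch keeps that guarantee deadlock-free by construction: AddConn blocks on a bounded channel, so a slow subscriber applies backpressure to new connections, while RemoveConn only appends to a slice and signals the emitter loop, so a subscriber that closes connections from its own event-handling goroutine never blocks against the emitter.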