Skip to content

Commit 4b1c188

Browse files
authored
feat: add common protocol design (#724)
* feat: add common protocol design * fix: remove redundant vars * fix: use AppDesign's ctx * refactor: relay, add AppDesign * feat: changes for suggestions * test: commonService start/stop execution * fix: lint error * nit: add comments
1 parent 2aea2f5 commit 4b1c188

File tree

10 files changed

+215
-206
lines changed

10 files changed

+215
-206
lines changed

waku/v2/protocol/common_service.go

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package protocol
2+
3+
import (
4+
"context"
5+
"errors"
6+
"sync"
7+
)
8+
9+
// this is common layout for all the services that require mutex protection and a guarantee that all running goroutines will be finished before stop finishes execution. This guarantee comes from waitGroup all one has to use CommonService.WaitGroup() in the goroutines that should finish by the end of stop function.
10+
type CommonService struct {
11+
sync.RWMutex
12+
cancel context.CancelFunc
13+
ctx context.Context
14+
wg sync.WaitGroup
15+
started bool
16+
}
17+
18+
func NewCommonService() *CommonService {
19+
return &CommonService{
20+
wg: sync.WaitGroup{},
21+
RWMutex: sync.RWMutex{},
22+
}
23+
}
24+
25+
// mutex protected start function
26+
// creates internal context over provided context and runs fn safely
27+
// fn is excerpt to be executed to start the protocol
28+
func (sp *CommonService) Start(ctx context.Context, fn func() error) error {
29+
sp.Lock()
30+
defer sp.Unlock()
31+
if sp.started {
32+
return ErrAlreadyStarted
33+
}
34+
sp.started = true
35+
sp.ctx, sp.cancel = context.WithCancel(ctx)
36+
if err := fn(); err != nil {
37+
sp.started = false
38+
sp.cancel()
39+
return err
40+
}
41+
return nil
42+
}
43+
44+
var ErrAlreadyStarted = errors.New("already started")
45+
var ErrNotStarted = errors.New("not started")
46+
47+
// mutex protected stop function
48+
func (sp *CommonService) Stop(fn func()) {
49+
sp.Lock()
50+
defer sp.Unlock()
51+
if !sp.started {
52+
return
53+
}
54+
sp.cancel()
55+
fn()
56+
sp.wg.Wait()
57+
sp.started = false
58+
}
59+
60+
// This is not a mutex protected function, it is up to the caller to use it in a mutex protected context
61+
func (sp *CommonService) ErrOnNotRunning() error {
62+
if !sp.started {
63+
return ErrNotStarted
64+
}
65+
return nil
66+
}
67+
68+
func (sp *CommonService) Context() context.Context {
69+
return sp.ctx
70+
}
71+
func (sp *CommonService) WaitGroup() *sync.WaitGroup {
72+
return &sp.wg
73+
}
+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package protocol
2+
3+
import (
4+
"context"
5+
"sync"
6+
"testing"
7+
)
8+
9+
// check if start and stop on common service works in random order
10+
func TestCommonService(t *testing.T) {
11+
s := NewCommonService()
12+
wg := &sync.WaitGroup{}
13+
for i := 0; i < 1000; i++ {
14+
wg.Add(1)
15+
if i%2 == 0 {
16+
go func() {
17+
wg.Done()
18+
_ = s.Start(context.TODO(), func() error { return nil })
19+
}()
20+
} else {
21+
go func() {
22+
wg.Done()
23+
go s.Stop(func() {})
24+
}()
25+
}
26+
}
27+
wg.Wait()
28+
}

waku/v2/protocol/filter/client.go

+36-79
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"fmt"
88
"math"
99
"net/http"
10-
"sync"
1110

1211
"github.com/libp2p/go-libp2p/core/host"
1312
"github.com/libp2p/go-libp2p/core/network"
@@ -34,16 +33,11 @@ var (
3433
)
3534

3635
type WakuFilterLightNode struct {
37-
sync.RWMutex
38-
started bool
39-
40-
cancel context.CancelFunc
41-
ctx context.Context
36+
*protocol.CommonService
4237
h host.Host
4338
broadcaster relay.Broadcaster //TODO: Move the broadcast functionality outside of relay client to a higher SDK layer.s
4439
timesource timesource.Timesource
4540
metrics Metrics
46-
wg *sync.WaitGroup
4741
log *zap.Logger
4842
subscriptions *SubscriptionsMap
4943
pm *peermanager.PeerManager
@@ -59,9 +53,6 @@ type WakuFilterPushResult struct {
5953
PeerID peer.ID
6054
}
6155

62-
var errNotStarted = errors.New("not started")
63-
var errAlreadyStarted = errors.New("already started")
64-
6556
// NewWakuFilterLightnode returns a new instance of Waku Filter struct setup according to the chosen parameter and options
6657
// Note that broadcaster is optional.
6758
// Takes an optional peermanager if WakuFilterLightnode is being created along with WakuNode.
@@ -72,8 +63,8 @@ func NewWakuFilterLightNode(broadcaster relay.Broadcaster, pm *peermanager.PeerM
7263
wf.log = log.Named("filterv2-lightnode")
7364
wf.broadcaster = broadcaster
7465
wf.timesource = timesource
75-
wf.wg = &sync.WaitGroup{}
7666
wf.pm = pm
67+
wf.CommonService = protocol.NewCommonService()
7768
wf.metrics = newMetrics(reg)
7869

7970
return wf
@@ -85,59 +76,36 @@ func (wf *WakuFilterLightNode) SetHost(h host.Host) {
8576
}
8677

8778
func (wf *WakuFilterLightNode) Start(ctx context.Context) error {
88-
wf.Lock()
89-
defer wf.Unlock()
79+
return wf.CommonService.Start(ctx, wf.start)
9080

91-
if wf.started {
92-
return errAlreadyStarted
93-
}
94-
95-
wf.wg.Wait() // Wait for any goroutines to stop
81+
}
9682

97-
ctx, cancel := context.WithCancel(ctx)
98-
wf.cancel = cancel
99-
wf.ctx = ctx
83+
func (wf *WakuFilterLightNode) start() error {
10084
wf.subscriptions = NewSubscriptionMap(wf.log)
101-
wf.started = true
102-
103-
wf.h.SetStreamHandlerMatch(FilterPushID_v20beta1, protocol.PrefixTextMatch(string(FilterPushID_v20beta1)), wf.onRequest(ctx))
85+
wf.h.SetStreamHandlerMatch(FilterPushID_v20beta1, protocol.PrefixTextMatch(string(FilterPushID_v20beta1)), wf.onRequest(wf.Context()))
10486

10587
wf.log.Info("filter-push protocol started")
106-
10788
return nil
10889
}
10990

11091
// Stop unmounts the filter protocol
11192
func (wf *WakuFilterLightNode) Stop() {
112-
wf.Lock()
113-
defer wf.Unlock()
114-
115-
if !wf.started {
116-
return
117-
}
118-
119-
wf.cancel()
120-
121-
wf.h.RemoveStreamHandler(FilterPushID_v20beta1)
122-
123-
res, err := wf.unsubscribeAll(wf.ctx)
124-
if err != nil {
125-
wf.log.Warn("unsubscribing from full nodes", zap.Error(err))
126-
}
127-
128-
for r := range res {
129-
if r.Err != nil {
130-
wf.log.Warn("unsubscribing from full nodes", zap.Error(r.Err), logging.HostID("peerID", r.PeerID))
93+
wf.CommonService.Stop(func() {
94+
wf.h.RemoveStreamHandler(FilterPushID_v20beta1)
95+
res, err := wf.unsubscribeAll(wf.Context())
96+
if err != nil {
97+
wf.log.Warn("unsubscribing from full nodes", zap.Error(err))
13198
}
13299

133-
}
134-
135-
wf.subscriptions.Clear()
136-
137-
wf.started = false
138-
wf.cancel = nil
100+
for r := range res {
101+
if r.Err != nil {
102+
wf.log.Warn("unsubscribing from full nodes", zap.Error(r.Err), logging.HostID("peerID", r.PeerID))
103+
}
139104

140-
wf.wg.Wait()
105+
}
106+
//
107+
wf.subscriptions.Clear()
108+
})
141109
}
142110

143111
func (wf *WakuFilterLightNode) onRequest(ctx context.Context) func(s network.Stream) {
@@ -248,9 +216,8 @@ func (wf *WakuFilterLightNode) request(ctx context.Context, params *FilterSubscr
248216
func (wf *WakuFilterLightNode) Subscribe(ctx context.Context, contentFilter ContentFilter, opts ...FilterSubscribeOption) (*SubscriptionDetails, error) {
249217
wf.RLock()
250218
defer wf.RUnlock()
251-
252-
if !wf.started {
253-
return nil, errNotStarted
219+
if err := wf.ErrOnNotRunning(); err != nil {
220+
return nil, err
254221
}
255222

256223
if contentFilter.Topic == "" {
@@ -285,17 +252,15 @@ func (wf *WakuFilterLightNode) Subscribe(ctx context.Context, contentFilter Cont
285252
if err != nil {
286253
return nil, err
287254
}
288-
289255
return wf.subscriptions.NewSubscription(params.selectedPeer, contentFilter.Topic, contentFilter.ContentTopics), nil
290256
}
291257

292258
// FilterSubscription is used to obtain an object from which you could receive messages received via filter protocol
293259
func (wf *WakuFilterLightNode) FilterSubscription(peerID peer.ID, contentFilter ContentFilter) (*SubscriptionDetails, error) {
294260
wf.RLock()
295261
defer wf.RUnlock()
296-
297-
if !wf.started {
298-
return nil, errNotStarted
262+
if err := wf.ErrOnNotRunning(); err != nil {
263+
return nil, err
299264
}
300265

301266
if !wf.subscriptions.Has(peerID, contentFilter.Topic, contentFilter.ContentTopics...) {
@@ -319,9 +284,8 @@ func (wf *WakuFilterLightNode) getUnsubscribeParameters(opts ...FilterUnsubscrib
319284
func (wf *WakuFilterLightNode) Ping(ctx context.Context, peerID peer.ID) error {
320285
wf.RLock()
321286
defer wf.RUnlock()
322-
323-
if !wf.started {
324-
return errNotStarted
287+
if err := wf.ErrOnNotRunning(); err != nil {
288+
return err
325289
}
326290

327291
return wf.request(
@@ -334,9 +298,8 @@ func (wf *WakuFilterLightNode) Ping(ctx context.Context, peerID peer.ID) error {
334298
func (wf *WakuFilterLightNode) IsSubscriptionAlive(ctx context.Context, subscription *SubscriptionDetails) error {
335299
wf.RLock()
336300
defer wf.RUnlock()
337-
338-
if !wf.started {
339-
return errNotStarted
301+
if err := wf.ErrOnNotRunning(); err != nil {
302+
return err
340303
}
341304

342305
return wf.Ping(ctx, subscription.PeerID)
@@ -345,8 +308,7 @@ func (wf *WakuFilterLightNode) IsSubscriptionAlive(ctx context.Context, subscrip
345308
func (wf *WakuFilterLightNode) Subscriptions() []*SubscriptionDetails {
346309
wf.RLock()
347310
defer wf.RUnlock()
348-
349-
if !wf.started {
311+
if err := wf.ErrOnNotRunning(); err != nil {
350312
return nil
351313
}
352314

@@ -398,13 +360,11 @@ func (wf *WakuFilterLightNode) cleanupSubscriptions(peerID peer.ID, contentFilte
398360
}
399361

400362
// Unsubscribe is used to stop receiving messages from a peer that match a content filter
401-
func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter ContentFilter,
402-
opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) {
363+
func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter ContentFilter, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) {
403364
wf.RLock()
404365
defer wf.RUnlock()
405-
406-
if !wf.started {
407-
return nil, errNotStarted
366+
if err := wf.ErrOnNotRunning(); err != nil {
367+
return nil, err
408368
}
409369

410370
if contentFilter.Topic == "" {
@@ -485,13 +445,11 @@ func (wf *WakuFilterLightNode) Unsubscribe(ctx context.Context, contentFilter Co
485445
}
486446

487447
// Unsubscribe is used to stop receiving messages from a peer that match a content filter
488-
func (wf *WakuFilterLightNode) UnsubscribeWithSubscription(ctx context.Context, sub *SubscriptionDetails,
489-
opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) {
448+
func (wf *WakuFilterLightNode) UnsubscribeWithSubscription(ctx context.Context, sub *SubscriptionDetails, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) {
490449
wf.RLock()
491450
defer wf.RUnlock()
492-
493-
if !wf.started {
494-
return nil, errNotStarted
451+
if err := wf.ErrOnNotRunning(); err != nil {
452+
return nil, err
495453
}
496454

497455
var contentTopics []string
@@ -563,9 +521,8 @@ func (wf *WakuFilterLightNode) unsubscribeAll(ctx context.Context, opts ...Filte
563521
func (wf *WakuFilterLightNode) UnsubscribeAll(ctx context.Context, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) {
564522
wf.RLock()
565523
defer wf.RUnlock()
566-
567-
if !wf.started {
568-
return nil, errNotStarted
524+
if err := wf.ErrOnNotRunning(); err != nil {
525+
return nil, err
569526
}
570527

571528
return wf.unsubscribeAll(ctx, opts...)

waku/v2/protocol/filter/filter_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ func (s *FilterTestSuite) TestRunningGuard() {
350350

351351
_, err := s.lightNode.Subscribe(s.ctx, contentFilter, WithPeer(s.fullNodeHost.ID()))
352352

353-
s.Require().ErrorIs(err, errNotStarted)
353+
s.Require().ErrorIs(err, protocol.ErrNotStarted)
354354

355355
err = s.lightNode.Start(s.ctx)
356356
s.Require().NoError(err)
@@ -398,7 +398,7 @@ func (s *FilterTestSuite) TestStartStop() {
398398
startNode := func() {
399399
for i := 0; i < 100; i++ {
400400
err := s.lightNode.Start(context.Background())
401-
if errors.Is(err, errAlreadyStarted) {
401+
if errors.Is(err, protocol.ErrAlreadyStarted) {
402402
continue
403403
}
404404
s.Require().NoError(err)

0 commit comments

Comments
 (0)