Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Option to Reject requests by default #58

Merged
merged 3 commits into from
Apr 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions graphsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ type OnRequestReceivedHook func(p peer.ID, request RequestData, hookActions Requ
// If it returns an error processing is halted and the original request is cancelled.
type OnResponseReceivedHook func(p peer.ID, responseData ResponseData) error

// UnregisterHookFunc is a function call to unregister a hook that was previously registered
type UnregisterHookFunc func()

// GraphExchange is a protocol that can exchange IPLD graphs based on a selector
type GraphExchange interface {
// Request initiates a new GraphSync request to the given peer using the given selector spec.
Expand All @@ -163,8 +166,8 @@ type GraphExchange interface {
// If overrideDefaultValidation is set to true, then if the hook does not error,
// it is considered to have "validated" the request -- and that validation supersedes
// the normal validation of requests Graphsync does (i.e. all selectors can be accepted)
RegisterRequestReceivedHook(hook OnRequestReceivedHook) error
RegisterRequestReceivedHook(hook OnRequestReceivedHook) UnregisterHookFunc

// RegisterResponseReceivedHook adds a hook that runs when a response is received
RegisterResponseReceivedHook(OnResponseReceivedHook) error
RegisterResponseReceivedHook(OnResponseReceivedHook) UnregisterHookFunc
}
81 changes: 50 additions & 31 deletions impl/graphsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ import (
"context"

"github.com/ipfs/go-graphsync"
"github.com/ipfs/go-graphsync/requestmanager/asyncloader"

gsmsg "github.com/ipfs/go-graphsync/message"
"github.com/ipfs/go-graphsync/messagequeue"
gsnet "github.com/ipfs/go-graphsync/network"
"github.com/ipfs/go-graphsync/peermanager"
"github.com/ipfs/go-graphsync/requestmanager"
"github.com/ipfs/go-graphsync/requestmanager/asyncloader"
"github.com/ipfs/go-graphsync/responsemanager"
"github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager"
"github.com/ipfs/go-graphsync/selectorvalidator"
logging "github.com/ipfs/go-log"
"github.com/ipfs/go-peertaskqueue"
ipld "github.com/ipld/go-ipld-prime"
Expand All @@ -21,26 +21,41 @@ import (

var log = logging.Logger("graphsync")

const maxRecursionDepth = 100

// GraphSync is an instance of a GraphSync exchange that implements
// the graphsync protocol.
type GraphSync struct {
network gsnet.GraphSyncNetwork
loader ipld.Loader
storer ipld.Storer
requestManager *requestmanager.RequestManager
responseManager *responsemanager.ResponseManager
asyncLoader *asyncloader.AsyncLoader
peerResponseManager *peerresponsemanager.PeerResponseManager
peerTaskQueue *peertaskqueue.PeerTaskQueue
peerManager *peermanager.PeerMessageManager
ctx context.Context
cancel context.CancelFunc
network gsnet.GraphSyncNetwork
loader ipld.Loader
storer ipld.Storer
requestManager *requestmanager.RequestManager
responseManager *responsemanager.ResponseManager
asyncLoader *asyncloader.AsyncLoader
peerResponseManager *peerresponsemanager.PeerResponseManager
peerTaskQueue *peertaskqueue.PeerTaskQueue
peerManager *peermanager.PeerMessageManager
ctx context.Context
cancel context.CancelFunc
unregisterDefaultValidator graphsync.UnregisterHookFunc
}

// Option defines the functional option type that can be used to configure
// graphsync instances
type Option func(*GraphSync)

// RejectAllRequestsByDefault means that without hooks registered
// that perform their own request validation, all requests are rejected
func RejectAllRequestsByDefault() Option {
return func(gs *GraphSync) {
gs.unregisterDefaultValidator()
}
}

// New creates a new GraphSync Exchange on the given network,
// and the given link loader+storer.
func New(parent context.Context, network gsnet.GraphSyncNetwork,
loader ipld.Loader, storer ipld.Storer) graphsync.GraphExchange {
loader ipld.Loader, storer ipld.Storer, options ...Option) graphsync.GraphExchange {
ctx, cancel := context.WithCancel(parent)

createMessageQueue := func(ctx context.Context, p peer.ID) peermanager.PeerQueue {
Expand All @@ -55,18 +70,24 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork,
}
peerResponseManager := peerresponsemanager.New(ctx, createdResponseQueue)
responseManager := responsemanager.New(ctx, loader, peerResponseManager, peerTaskQueue)
unregisterDefaultValidator := responseManager.RegisterHook(selectorvalidator.SelectorValidator(maxRecursionDepth))
graphSync := &GraphSync{
network: network,
loader: loader,
storer: storer,
asyncLoader: asyncLoader,
requestManager: requestManager,
peerManager: peerManager,
peerTaskQueue: peerTaskQueue,
peerResponseManager: peerResponseManager,
responseManager: responseManager,
ctx: ctx,
cancel: cancel,
network: network,
loader: loader,
storer: storer,
asyncLoader: asyncLoader,
requestManager: requestManager,
peerManager: peerManager,
peerTaskQueue: peerTaskQueue,
peerResponseManager: peerResponseManager,
responseManager: responseManager,
ctx: ctx,
cancel: cancel,
unregisterDefaultValidator: unregisterDefaultValidator,
}

for _, option := range options {
option(graphSync)
}

asyncLoader.Startup()
Expand All @@ -86,15 +107,13 @@ func (gs *GraphSync) Request(ctx context.Context, p peer.ID, root ipld.Link, sel
// If overrideDefaultValidation is set to true, then if the hook does not error,
// it is considered to have "validated" the request -- and that validation supersedes
// the normal validation of requests Graphsync does (i.e. all selectors can be accepted)
func (gs *GraphSync) RegisterRequestReceivedHook(hook graphsync.OnRequestReceivedHook) error {
gs.responseManager.RegisterHook(hook)
return nil
func (gs *GraphSync) RegisterRequestReceivedHook(hook graphsync.OnRequestReceivedHook) graphsync.UnregisterHookFunc {
return gs.responseManager.RegisterHook(hook)
}

// RegisterResponseReceivedHook adds a hook that runs when a response is received
func (gs *GraphSync) RegisterResponseReceivedHook(hook graphsync.OnResponseReceivedHook) error {
gs.requestManager.RegisterHook(hook)
return nil
func (gs *GraphSync) RegisterResponseReceivedHook(hook graphsync.OnResponseReceivedHook) graphsync.UnregisterHookFunc {
return gs.requestManager.RegisterHook(hook)
}

type graphSyncReceiver GraphSync
Expand Down
45 changes: 31 additions & 14 deletions impl/graphsync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,14 @@ func TestSendResponseToIncomingRequest(t *testing.T) {
var receivedRequestData []byte
// initialize graphsync on second node to response to requests
gsnet := td.GraphSyncHost2()
err := gsnet.RegisterRequestReceivedHook(
gsnet.RegisterRequestReceivedHook(
func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
var has bool
receivedRequestData, has = requestData.Extension(td.extensionName)
require.True(t, has, "did not have expected extension")
hookActions.SendExtensionData(td.extensionResponse)
},
)
require.NoError(t, err, "error registering extension")

blockChainLength := 100
blockChain := testutil.SetupBlockChain(ctx, t, td.loader2, td.storer2, 100, blockChainLength)
Expand All @@ -117,7 +116,7 @@ func TestSendResponseToIncomingRequest(t *testing.T) {
message := gsmsg.New()
message.AddRequest(gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32), td.extension))
// send request across network
err = td.gsnet1.SendMessage(ctx, td.host2.ID(), message)
err := td.gsnet1.SendMessage(ctx, td.host2.ID(), message)
require.NoError(t, err)
// read the values sent back to requestor
var received gsmsg.GraphSyncMessage
Expand Down Expand Up @@ -150,6 +149,27 @@ func TestSendResponseToIncomingRequest(t *testing.T) {
require.Equal(t, td.extensionResponseData, receivedExtensions[0], "did not return correct extension data")
}

func TestRejectRequestsByDefault(t *testing.T) {
// create network
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()
td := newGsTestData(ctx, t)

requestor := td.GraphSyncHost1()
// setup responder to disable default validation, meaning all requests are rejected
_ = td.GraphSyncHost2(RejectAllRequestsByDefault())

blockChainLength := 5
blockChain := testutil.SetupBlockChain(ctx, t, td.loader2, td.storer2, 5, blockChainLength)

// send request across network
progressChan, errChan := requestor.Request(ctx, td.host2.ID(), blockChain.TipLink, blockChain.Selector(), td.extension)

testutil.VerifyEmptyResponse(ctx, t, progressChan)
testutil.VerifySingleTerminalError(ctx, t, errChan)
}

func TestGraphsyncRoundTrip(t *testing.T) {
// create network
ctx := context.Background()
Expand All @@ -170,17 +190,16 @@ func TestGraphsyncRoundTrip(t *testing.T) {
var receivedResponseData []byte
var receivedRequestData []byte

err := requestor.RegisterResponseReceivedHook(
requestor.RegisterResponseReceivedHook(
func(p peer.ID, responseData graphsync.ResponseData) error {
data, has := responseData.Extension(td.extensionName)
if has {
receivedResponseData = data
}
return nil
})
require.NoError(t, err, "Error setting up extension")

err = responder.RegisterRequestReceivedHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
responder.RegisterRequestReceivedHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
var has bool
receivedRequestData, has = requestData.Extension(td.extensionName)
if !has {
Expand All @@ -189,7 +208,6 @@ func TestGraphsyncRoundTrip(t *testing.T) {
hookActions.SendExtensionData(td.extensionResponse)
}
})
require.NoError(t, err, "Error setting up extension")

progressChan, errChan := requestor.Request(ctx, td.host2.ID(), blockChain.TipLink, blockChain.Selector(), td.extension)

Expand Down Expand Up @@ -342,15 +360,14 @@ func TestUnixFSFetch(t *testing.T) {
requestor := New(ctx, td.gsnet1, loader1, storer1)
responder := New(ctx, td.gsnet2, loader2, storer2)
extensionName := graphsync.ExtensionName("Free for all")
err = responder.RegisterRequestReceivedHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
responder.RegisterRequestReceivedHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
hookActions.ValidateRequest()
hookActions.SendExtensionData(graphsync.ExtensionData{
Name: extensionName,
Data: nil,
})
})
require.NoError(t, err)


// make a go-ipld-prime link for the root UnixFS node
clink := cidlink.Link{Cid: nd.Cid()}

Expand Down Expand Up @@ -443,13 +460,13 @@ func newGsTestData(ctx context.Context, t *testing.T) *gsTestData {
return td
}

func (td *gsTestData) GraphSyncHost1() graphsync.GraphExchange {
return New(td.ctx, td.gsnet1, td.loader1, td.storer1)
func (td *gsTestData) GraphSyncHost1(options ...Option) graphsync.GraphExchange {
return New(td.ctx, td.gsnet1, td.loader1, td.storer1, options...)
}

func (td *gsTestData) GraphSyncHost2() graphsync.GraphExchange {
func (td *gsTestData) GraphSyncHost2(options ...Option) graphsync.GraphExchange {

return New(td.ctx, td.gsnet2, td.loader2, td.storer2)
return New(td.ctx, td.gsnet2, td.loader2, td.storer2, options...)
}

type receivedMessage struct {
Expand Down
36 changes: 32 additions & 4 deletions requestmanager/requestmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type inProgressRequestStatus struct {
}

type responseHook struct {
key uint64
hook graphsync.OnResponseReceivedHook
}

Expand Down Expand Up @@ -65,6 +66,7 @@ type RequestManager struct {
// dont touch out side of run loop
nextRequestID graphsync.RequestID
inProgressRequestStatuses map[graphsync.RequestID]*inProgressRequestStatus
responseHookNextKey uint64
responseHooks []responseHook
}

Expand Down Expand Up @@ -201,12 +203,25 @@ func (rm *RequestManager) ProcessResponses(p peer.ID, responses []gsmsg.GraphSyn
}
}

type registerHookMessage struct {
hook graphsync.OnResponseReceivedHook
unregisterHookChan chan graphsync.UnregisterHookFunc
}

// RegisterHook registers an extension to processincoming responses
func (rm *RequestManager) RegisterHook(
hook graphsync.OnResponseReceivedHook) {
hook graphsync.OnResponseReceivedHook) graphsync.UnregisterHookFunc {
response := make(chan graphsync.UnregisterHookFunc)
select {
case rm.messages <- &registerHookMessage{hook, response}:
case <-rm.ctx.Done():
return nil
}
select {
case rm.messages <- &responseHook{hook}:
case unregister := <-response:
return unregister
case <-rm.ctx.Done():
return nil
}
}

Expand Down Expand Up @@ -285,8 +300,21 @@ func (prm *processResponseMessage) handle(rm *RequestManager) {
rm.processTerminations(filteredResponses)
}

func (rh *responseHook) handle(rm *RequestManager) {
rm.responseHooks = append(rm.responseHooks, *rh)
func (rhm *registerHookMessage) handle(rm *RequestManager) {
rh := responseHook{rm.responseHookNextKey, rhm.hook}
rm.responseHookNextKey++
rm.responseHooks = append(rm.responseHooks, rh)
select {
case rhm.unregisterHookChan <- func() {
for i, matchHook := range rm.responseHooks {
if rh.key == matchHook.key {
rm.responseHooks = append(rm.responseHooks[:i], rm.responseHooks[i+1:]...)
return
}
}
}:
case <-rm.ctx.Done():
}
}

func (rm *RequestManager) filterResponsesForPeer(responses []gsmsg.GraphSyncResponse, p peer.ID) []gsmsg.GraphSyncResponse {
Expand Down
Loading