Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/ckousik/initiator-noise-handshak…
Browse files Browse the repository at this point in the history
…e' into ckousik/webrtc
  • Loading branch information
ckousik committed Oct 18, 2022
2 parents f5bf378 + c172f52 commit 10cb830
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 19 deletions.
2 changes: 1 addition & 1 deletion p2p/security/noise/crypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestCryptoFailsIfHandshakeIncomplete(t *testing.T) {
init, resp := net.Pipe()
_ = resp.Close()

session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", nil, nil, nil, true)
session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", nil, nil, nil, true, true)
_, err := session.encrypt(nil, []byte("hi"))
if err == nil {
t.Error("expected encryption error when handshake incomplete")
Expand Down
11 changes: 6 additions & 5 deletions p2p/security/noise/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,13 +272,14 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStati
return nil, err
}

// check the peer ID for:
// * all outbound connection
// * inbound connections, if we know which peer we want to connect to (SecureInbound called with a peer ID)
if (s.initiator && s.remoteID != id) || (!s.initiator && s.remoteID != "" && s.remoteID != id) {
// use Pretty() as it produces the full b58-encoded string, rather than abbreviated forms.
// check the peer ID if enabled
if s.checkPeerID && s.remoteID != id {
return nil, fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty())
}
// if (s.initiator && s.remoteID != id) || (!s.initiator && s.remoteID != "" && s.remoteID != id) {
// // use Pretty() as it produces the full b58-encoded string, rather than abbreviated forms.
// return nil, fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty())
// }

// verify payload is signed by asserted remote libp2p key.
sig := nhp.GetIdentitySig()
Expand Down
6 changes: 4 additions & 2 deletions p2p/security/noise/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import (
)

type secureSession struct {
initiator bool
initiator bool
checkPeerID bool

localID peer.ID
localKey crypto.PrivKey
Expand Down Expand Up @@ -44,7 +45,7 @@ type secureSession struct {

// newSecureSession creates a Noise session over the given insecureConn Conn, using
// the libp2p identity keypair from the given Transport.
func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, prologue []byte, initiatorEDH, responderEDH EarlyDataHandler, initiator bool) (*secureSession, error) {
func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, prologue []byte, initiatorEDH, responderEDH EarlyDataHandler, initiator, checkPeerID bool) (*secureSession, error) {
s := &secureSession{
insecureConn: insecure,
insecureReader: bufio.NewReader(insecure),
Expand All @@ -55,6 +56,7 @@ func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, re
prologue: prologue,
initiatorEarlyDataHandler: initiatorEDH,
responderEarlyDataHandler: responderEDH,
checkPeerID: checkPeerID,
}

// the go-routine we create to run the handshake will
Expand Down
14 changes: 11 additions & 3 deletions p2p/security/noise/session_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,30 @@ func EarlyData(initiator, responder EarlyDataHandler) SessionOption {
}
}

func CheckPeerID(enable bool) SessionOption {
return func(s *SessionTransport) error {
s.checkPeerID = enable
return nil
}
}

var _ sec.SecureTransport = &SessionTransport{}

// SessionTransport can be used
// to provide per-connection options
type SessionTransport struct {
t *Transport
// options
prologue []byte
prologue []byte
checkPeerID bool

initiatorEarlyDataHandler, responderEarlyDataHandler EarlyDataHandler
}

// SecureInbound runs the Noise handshake as the responder.
// If p is empty, connections from any peer are accepted.
func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.responderEarlyDataHandler, false)
c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.responderEarlyDataHandler, false, i.checkPeerID)
if err != nil {
addr, maErr := manet.FromNetAddr(insecure.RemoteAddr())
if maErr == nil {
Expand All @@ -77,5 +85,5 @@ func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn,

// SecureOutbound runs the Noise handshake as the initiator.
func (i *SessionTransport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
return newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.responderEarlyDataHandler, true)
return newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.responderEarlyDataHandler, true, i.checkPeerID)
}
12 changes: 8 additions & 4 deletions p2p/security/noise/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,13 @@ func New(privkey crypto.PrivKey) (*Transport, error) {
}

// SecureInbound runs the Noise handshake as the responder.
// If p is empty, connections from any peer are accepted.
// if p is empty accept any peer ID.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
c, err := newSecureSession(t, ctx, insecure, p, nil, nil, nil, false)
checkPeerID := true
if p == "" {
checkPeerID = false
}
c, err := newSecureSession(t, ctx, insecure, p, nil, nil, nil, false, checkPeerID)
if err != nil {
addr, maErr := manet.FromNetAddr(insecure.RemoteAddr())
if maErr == nil {
Expand All @@ -53,10 +57,10 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer

// SecureOutbound runs the Noise handshake as the initiator.
func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
return newSecureSession(t, ctx, insecure, p, nil, nil, nil, true)
return newSecureSession(t, ctx, insecure, p, nil, nil, nil, true, true)
}

func (t *Transport) WithSessionOptions(opts ...SessionOption) (sec.SecureTransport, error) {
func (t *Transport) WithSessionOptions(opts ...SessionOption) (*SessionTransport, error) {
st := &SessionTransport{t: t}
for _, opt := range opts {
if err := opt(st); err != nil {
Expand Down
27 changes: 23 additions & 4 deletions p2p/security/noise/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ func connect(t *testing.T, initTransport, respTransport *Transport) (*secureSess
initConn, initErr = initTransport.SecureOutbound(context.Background(), init, respTransport.localID)
}()

respConn, respErr := respTransport.SecureInbound(context.Background(), resp, "")
receiverTpt, _ := respTransport.WithSessionOptions(CheckPeerID(false))
respConn, respErr := receiverTpt.SecureInbound(context.Background(), resp, "")
<-done

if initErr != nil {
Expand Down Expand Up @@ -190,7 +191,7 @@ func TestPeerIDMatch(t *testing.T) {

func TestPeerIDMismatchOutboundFailsHandshake(t *testing.T) {
initTransport := newTestTransport(t, crypto.Ed25519, 2048)
respTransport := newTestTransport(t, crypto.Ed25519, 2048)
respTransport, _ := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(CheckPeerID(false))
init, resp := newConnPair(t)

errChan := make(chan error)
Expand Down Expand Up @@ -226,6 +227,24 @@ func TestPeerIDMismatchInboundFailsHandshake(t *testing.T) {
<-done
}

func TestPeerIDOutboundNoCheck(t *testing.T) {
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(CheckPeerID(false))
require.NoError(t, err, "could not initiate session transport")
respTransport := newTestTransport(t, crypto.Ed25519, 2048)
init, resp := newConnPair(t)

errChan := make(chan error)
go func() {
_, err := initTransport.SecureOutbound(context.Background(), init, "")
errChan <- err
}()

_, err = respTransport.SecureInbound(context.Background(), resp, "")
require.NoError(t, err)
initErr := <-errChan
require.NoError(t, initErr)
}

func makeLargePlaintext(size int) []byte {
buf := make([]byte, size)
rand.Read(buf)
Expand Down Expand Up @@ -565,7 +584,7 @@ func TestEarlyfffDataAcceptedWithNoHandler(t *testing.T) {
}
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH, nil))
require.NoError(t, err)
respTransport := newTestTransport(t, crypto.Ed25519, 2048)
respTransport, _ := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(CheckPeerID(false))

initConn, respConn := newConnPair(t)

Expand All @@ -575,7 +594,7 @@ func TestEarlyfffDataAcceptedWithNoHandler(t *testing.T) {
errChan <- err
}()

conn, err := initTransport.SecureOutbound(context.Background(), respConn, respTransport.localID)
conn, err := initTransport.SecureOutbound(context.Background(), respConn, respTransport.t.localID)
require.NoError(t, err)
defer conn.Close()

Expand Down

0 comments on commit 10cb830

Please sign in to comment.