Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

p2p: add dynamic DNS resolution for nodes #30822

Merged
merged 14 commits into from
Dec 13, 2024
Merged
143 changes: 114 additions & 29 deletions p2p/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ import (
"fmt"
mrand "math/rand"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"

"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/p2p/netutil"
)

Expand Down Expand Up @@ -77,6 +79,7 @@ var (
errRecentlyDialed = errors.New("recently dialed")
errNetRestrict = errors.New("not contained in netrestrict list")
errNoPort = errors.New("node does not provide TCP port")
errNoResolvedIP = errors.New("node does not provide a resolved IP")
)

// dialer creates outbound connections and submits them into Server.
Expand All @@ -90,16 +93,17 @@ var (
// to create peer connections to nodes arriving through the iterator.
type dialScheduler struct {
dialConfig
setupFunc dialSetupFunc
wg sync.WaitGroup
cancel context.CancelFunc
ctx context.Context
nodesIn chan *enode.Node
doneCh chan *dialTask
addStaticCh chan *enode.Node
remStaticCh chan *enode.Node
addPeerCh chan *conn
remPeerCh chan *conn
setupFunc dialSetupFunc
dnsLookupFunc func(ctx context.Context, network string, name string) ([]netip.Addr, error)
wg sync.WaitGroup
cancel context.CancelFunc
ctx context.Context
nodesIn chan *enode.Node
doneCh chan *dialTask
addStaticCh chan *enode.Node
remStaticCh chan *enode.Node
addPeerCh chan *conn
remPeerCh chan *conn

// Everything below here belongs to loop and
// should only be accessed by code on the loop goroutine.
Expand Down Expand Up @@ -159,18 +163,19 @@ func (cfg dialConfig) withDefaults() dialConfig {
func newDialScheduler(config dialConfig, it enode.Iterator, setupFunc dialSetupFunc) *dialScheduler {
cfg := config.withDefaults()
d := &dialScheduler{
dialConfig: cfg,
historyTimer: mclock.NewAlarm(cfg.clock),
setupFunc: setupFunc,
dialing: make(map[enode.ID]*dialTask),
static: make(map[enode.ID]*dialTask),
peers: make(map[enode.ID]struct{}),
doneCh: make(chan *dialTask),
nodesIn: make(chan *enode.Node),
addStaticCh: make(chan *enode.Node),
remStaticCh: make(chan *enode.Node),
addPeerCh: make(chan *conn),
remPeerCh: make(chan *conn),
dialConfig: cfg,
historyTimer: mclock.NewAlarm(cfg.clock),
setupFunc: setupFunc,
dnsLookupFunc: net.DefaultResolver.LookupNetIP,
dialing: make(map[enode.ID]*dialTask),
static: make(map[enode.ID]*dialTask),
peers: make(map[enode.ID]struct{}),
doneCh: make(chan *dialTask),
nodesIn: make(chan *enode.Node),
addStaticCh: make(chan *enode.Node),
remStaticCh: make(chan *enode.Node),
addPeerCh: make(chan *conn),
remPeerCh: make(chan *conn),
}
d.lastStatsLog = d.clock.Now()
d.ctx, d.cancel = context.WithCancel(context.Background())
Expand Down Expand Up @@ -274,7 +279,7 @@ loop:
case node := <-d.addStaticCh:
id := node.ID()
_, exists := d.static[id]
d.log.Trace("Adding static node", "id", id, "ip", node.IPAddr(), "added", !exists)
d.log.Trace("Adding static node", "id", id, "endpoint", nodeEndpointForLog(node), "added", !exists)
if exists {
continue loop
}
Expand Down Expand Up @@ -433,10 +438,68 @@ func (d *dialScheduler) removeFromStaticPool(idx int) {
task.staticPoolIndex = -1
}

// dnsResolveHostname updates the given node from its DNS hostname.
// This is used to resolve static dial targets.
func (d *dialScheduler) dnsResolveHostname(n *enode.Node) (*enode.Node, error) {
if n.Hostname() == "" {
return n, nil
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
foundIPs, err := d.dnsLookupFunc(ctx, "ip", n.Hostname())
if err != nil {
return n, err
}

// Check for IP updates.
var (
nodeIP4, nodeIP6 netip.Addr
foundIP4, foundIP6 netip.Addr
)
n.Load((*enr.IPv4Addr)(&nodeIP4))
n.Load((*enr.IPv6Addr)(&nodeIP6))
for _, ip := range foundIPs {
if ip.Is4() && !foundIP4.IsValid() {
foundIP4 = ip
}
if ip.Is6() && !foundIP6.IsValid() {
foundIP6 = ip
}
}

if !foundIP4.IsValid() && !foundIP6.IsValid() {
// Lookup failed.
return n, errNoResolvedIP
}
if foundIP4 == nodeIP4 && foundIP6 == nodeIP6 {
// No updates necessary.
d.log.Trace("Node DNS lookup had no update", "id", n.ID(), "name", n.Hostname(), "ip", foundIP4, "ip6", foundIP6)
return n, nil
}

// Update the node. Note this invalidates the ENR signature, because we use SignNull
// to create a modified copy. But this should be OK, since we just use the node as a
// dial target. And nodes will usually only have a DNS hostname if they came from a
// enode:// URL, which has no signature anyway. If it ever becomes a problem, the
// resolved IP could also be stored into dialTask instead of the node.
rec := n.Record()
if foundIP4.IsValid() {
rec.Set(enr.IPv4Addr(foundIP4))
}
if foundIP6.IsValid() {
rec.Set(enr.IPv6Addr(foundIP6))
}
rec.SetSeq(n.Seq()) // ensure seq not bumped by update
newNode := enode.SignNull(rec, n.ID()).WithHostname(n.Hostname())
d.log.Debug("Node updated from DNS lookup", "id", n.ID(), "name", n.Hostname(), "ip", newNode.IP())
return newNode, nil
}

// startDial runs the given dial task in a separate goroutine.
func (d *dialScheduler) startDial(task *dialTask) {
node := task.dest()
d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IPAddr(), "flag", task.flags)
d.log.Trace("Starting p2p dial", "id", node.ID(), "endpoint", nodeEndpointForLog(node), "flag", task.flags)
hkey := string(node.ID().Bytes())
d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration))
d.dialing[node.ID()] = task
Expand Down Expand Up @@ -473,23 +536,38 @@ func (t *dialTask) dest() *enode.Node {
}

func (t *dialTask) run(d *dialScheduler) {
if t.needResolve() && !t.resolve(d) {
return
if t.isStatic() {
// Resolve DNS.
if n := t.dest(); n.Hostname() != "" {
resolved, err := d.dnsResolveHostname(n)
if err != nil {
d.log.Warn("DNS lookup of static node failed", "id", n.ID(), "name", n.Hostname(), "err", err)
} else {
t.destPtr.Store(resolved)
}
}
// Try resolving node ID through the DHT if there is no IP address.
if !t.dest().IPAddr().IsValid() {
if !t.resolve(d) {
return // DHT resolve failed, skip dial.
}
}
}

err := t.dial(d, t.dest())
if err != nil {
// For static nodes, resolve one more time if dialing fails.
if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
var dialErr *dialError
if errors.As(err, &dialErr) && t.isStatic() {
if t.resolve(d) {
t.dial(d, t.dest())
}
}
}
}

func (t *dialTask) needResolve() bool {
return t.flags&staticDialedConn != 0 && !t.dest().IPAddr().IsValid()
func (t *dialTask) isStatic() bool {
return t.flags&staticDialedConn != 0
}

// resolve attempts to find the current endpoint for the destination
Expand Down Expand Up @@ -553,3 +631,10 @@ func cleanupDialErr(err error) error {
}
return err
}

func nodeEndpointForLog(n *enode.Node) string {
if n.Hostname() != "" {
return n.Hostname()
}
return n.IPAddr().String()
}
29 changes: 29 additions & 0 deletions p2p/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"math/rand"
"net"
"net/netip"
"reflect"
"sync"
"testing"
Expand Down Expand Up @@ -394,6 +395,34 @@ func TestDialSchedResolve(t *testing.T) {
})
}

func TestDialSchedDNSHostname(t *testing.T) {
t.Parallel()

config := dialConfig{
maxActiveDials: 1,
maxDialPeers: 1,
}
node := newNode(uintID(0x01), ":30303").WithHostname("node-hostname")
resolved := newNode(uintID(0x01), "1.2.3.4:30303").WithHostname("node-hostname")
runDialTest(t, config, []dialTestRound{
{
update: func(d *dialScheduler) {
d.dnsLookupFunc = func(ctx context.Context, network string, name string) ([]netip.Addr, error) {
if name != "node-hostname" {
t.Error("wrong hostname in DNS lookup:", name)
}
result := []netip.Addr{netip.MustParseAddr("1.2.3.4")}
return result, nil
}
d.addStatic(node)
},
wantNewDials: []*enode.Node{
resolved,
},
},
})
}

// -------
// Code below here is the framework for the tests above.

Expand Down
22 changes: 22 additions & 0 deletions p2p/enode/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ var errMissingPrefix = errors.New("missing 'enr:' prefix for base64-encoded reco
type Node struct {
r enr.Record
id ID

// hostname tracks the DNS name of the node.
hostname string

// endpoint information
ip netip.Addr
udp uint16
Expand Down Expand Up @@ -77,6 +81,8 @@ func newNodeWithID(r *enr.Record, id ID) *Node {
n.setIP4(ip4)
case valid6:
n.setIP6(ip6)
default:
n.setIPv4Ports()
}
return n
}
Expand All @@ -103,6 +109,10 @@ func localityScore(ip netip.Addr) int {

func (n *Node) setIP4(ip netip.Addr) {
n.ip = ip
n.setIPv4Ports()
}

func (n *Node) setIPv4Ports() {
n.Load((*enr.UDP)(&n.udp))
n.Load((*enr.TCP)(&n.tcp))
}
Expand Down Expand Up @@ -184,6 +194,18 @@ func (n *Node) TCP() int {
return int(n.tcp)
}

// WithHostname adds a DNS hostname to the node.
func (n *Node) WithHostname(hostname string) *Node {
cpy := *n
cpy.hostname = hostname
return &cpy
}

// Hostname returns the DNS name assigned by WithHostname.
func (n *Node) Hostname() string {
return n.hostname
}

// UDPEndpoint returns the announced UDP endpoint.
func (n *Node) UDPEndpoint() (netip.AddrPort, bool) {
if !n.ip.IsValid() || n.ip.IsUnspecified() || n.udp == 0 {
Expand Down
19 changes: 19 additions & 0 deletions p2p/enode/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ func TestNodeEndpoints(t *testing.T) {
wantUDP int
wantTCP int
wantQUIC int
wantDNS string
}
tests := []endpointTest{
{
Expand All @@ -90,6 +91,7 @@ func TestNodeEndpoints(t *testing.T) {
r.Set(enr.UDP(9000))
return SignNull(&r, id)
}(),
wantUDP: 9000,
},
{
name: "tcp-only",
Expand All @@ -98,6 +100,7 @@ func TestNodeEndpoints(t *testing.T) {
r.Set(enr.TCP(9000))
return SignNull(&r, id)
}(),
wantTCP: 9000,
},
{
name: "quic-only",
Expand Down Expand Up @@ -268,6 +271,19 @@ func TestNodeEndpoints(t *testing.T) {
wantIP: netip.MustParseAddr("2001::ff00:0042:8329"),
wantQUIC: 9001,
},
{
name: "dns-only",
node: func() *Node {
var r enr.Record
r.Set(enr.UDP(30303))
r.Set(enr.TCP(30303))
n := SignNull(&r, id).WithHostname("example.com")
return n
}(),
wantTCP: 30303,
wantUDP: 30303,
wantDNS: "example.com",
},
}

for _, test := range tests {
Expand All @@ -284,6 +300,9 @@ func TestNodeEndpoints(t *testing.T) {
if quic, _ := test.node.QUICEndpoint(); test.wantQUIC != int(quic.Port()) {
t.Errorf("node has wrong QUIC port %d, want %d", quic.Port(), test.wantQUIC)
}
if test.wantDNS != test.node.Hostname() {
t.Errorf("node has wrong DNS name %s, want %s", test.node.Hostname(), test.wantDNS)
}
})
}
}
Expand Down
Loading