Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

p2p: add dynamic DNS resolution for nodes #30822

Merged
merged 14 commits into from
Dec 13, 2024
Merged
46 changes: 38 additions & 8 deletions p2p/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ type dialConfig struct {
log log.Logger
clock mclock.Clock
rand *mrand.Rand
ttl time.Duration
}

func (cfg dialConfig) withDefaults() dialConfig {
Expand Down Expand Up @@ -274,13 +275,15 @@ loop:
case node := <-d.addStaticCh:
id := node.ID()
_, exists := d.static[id]
d.log.Trace("Adding static node", "id", id, "ip", node.IPAddr(), "added", !exists)
d.log.Trace("Adding static node", "id", id, "addr", node.DisplayAddr(), "ip", node.IPAddr(), "added", !exists)
if exists {
continue loop
}
task := newDialTask(node, staticDialedConn)
d.static[id] = task
if d.checkDial(node) == nil {
if err := d.checkDial(node); err != nil {
d.log.Trace("Discarding dial candidate", "id", node.ID(), "addr", node.DisplayAddr(), "ip", node.IPAddr(), "reason", err)
} else {
d.addToStaticPool(task)
}

Expand Down Expand Up @@ -436,7 +439,7 @@ func (d *dialScheduler) removeFromStaticPool(idx int) {
// startDial runs the given dial task in a separate goroutine.
func (d *dialScheduler) startDial(task *dialTask) {
node := task.dest()
d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IPAddr(), "flag", task.flags)
d.log.Trace("Starting p2p dial", "id", node.ID(), "addr", node.DisplayAddr(), "ip", node.IPAddr(), "flag", task.flags)
hkey := string(node.ID().Bytes())
d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration))
d.dialing[node.ID()] = task
Expand Down Expand Up @@ -473,7 +476,7 @@ func (t *dialTask) dest() *enode.Node {
}

func (t *dialTask) run(d *dialScheduler) {
if t.needResolve() && !t.resolve(d) {
if !t.resolveIfNeeded(d) {
return
}

Expand All @@ -488,8 +491,36 @@ func (t *dialTask) run(d *dialScheduler) {
}
}

func (t *dialTask) needResolve() bool {
return t.flags&staticDialedConn != 0 && !t.dest().IPAddr().IsValid()
// resolveIfNeeded attempts to resolve the node's IP address if it is invalid.
// It returns true if the node's IP address is valid after resolution attempts.
func (t *dialTask) resolveIfNeeded(d *dialScheduler) bool {
node := t.dest()

if t.flags&staticDialedConn != 0 {
if t.resolve(d) && node.IPAddr().IsValid() {
return true
}
}

if node.NeedsDNSResolve() {
if t.resolveDNS(d) && node.IPAddr().IsValid() {
return true
}
}

return false
}

// resolveDNS attempts to resolve the DNS name of the destination node.
// It returns true if resolution succeeds.
func (t *dialTask) resolveDNS(d *dialScheduler) bool {
node := t.dest()
d.log.Trace("Starting DNS resolution", "id", node.ID(), "addr", node.DisplayAddr())
if err := node.RefreshDNS(d.dialConfig.ttl); err != nil {
d.log.Trace("DNS resolution failed", "id", node.ID(), "addr", node.DisplayAddr())
return false
}
return true
}

// resolve attempts to find the current endpoint for the destination
Expand Down Expand Up @@ -533,8 +564,7 @@ func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error {
dialMeter.Mark(1)
fd, err := d.dialer.Dial(d.ctx, dest)
if err != nil {
addr, _ := dest.TCPEndpoint()
d.log.Trace("Dial error", "id", dest.ID(), "addr", addr, "conn", t.flags, "err", cleanupDialErr(err))
d.log.Trace("Dial error", "id", dest.ID(), "addr", dest.DisplayAddr(), "ip", dest.IPAddr(), "conn", t.flags, "err", cleanupDialErr(err))
dialConnectionError.Mark(1)
return &dialError{err}
}
Expand Down
83 changes: 83 additions & 0 deletions p2p/enode/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"net"
"net/netip"
"strings"
"time"

"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/rlp"
Expand All @@ -41,6 +42,10 @@ type Node struct {
ip netip.Addr
udp uint16
tcp uint16
// dns information
dnsName string
dnsResolved time.Time
dnsTTL time.Duration
}

// New wraps a node record. The record must be valid according to the given
Expand Down Expand Up @@ -257,6 +262,84 @@ func (n *Node) String() string {
return "enr:" + b64
}

// resolveDNS attempts to resolve a DNS name to an IP address
func (n *Node) resolveDNS(dnsName string) (netip.Addr, error) {
ips, err := net.LookupIP(dnsName)
if err != nil {
return netip.Addr{}, err
}
for _, ip := range ips {
if ip4 := ip.To4(); ip4 != nil {
addr, ok := netip.AddrFromSlice(ip4)
if ok {
return addr, nil
}
}
}
// Fall back to IPv6 if no IPv4 is available
for _, ip := range ips {
addr, ok := netip.AddrFromSlice(ip)
if ok {
return addr, nil
}
}

return netip.Addr{}, errors.New("no valid IP address found")
}

// SetDNS sets the DNS name and resolves it to an IP address
func (n *Node) SetDNS(dnsName string, ttl time.Duration) error {
ip, err := n.resolveDNS(dnsName)
if err != nil {
return err
}

n.dnsName = dnsName
n.dnsResolved = time.Now()
n.dnsTTL = ttl

if ip.Is4() {
n.setIP4(ip)
} else {
n.setIP6(ip)
}
return nil
}

// DNSName returns the stored DNS name
func (n *Node) DNSName() string {
return n.dnsName
}

// RefreshDNS updates the IP address from the stored DNS name
func (n *Node) RefreshDNS(ttl time.Duration) error {
if n.dnsName == "" {
return errors.New("no DNS name set")
}
return n.SetDNS(n.dnsName, ttl)
}

// DisplayAddr returns either "hostname:port" or "ip:port"
func (n *Node) DisplayAddr() string {
addr := n.dnsName
if addr == "" {
addr = n.ip.String()
}
return fmt.Sprintf("%s:%d", addr, n.tcp)
}

// NeedsDNSResolve returns true if the node has a DNS name that needs resolution
func (n *Node) NeedsDNSResolve() bool {
if n.dnsName == "" {
return false
}
return !n.ip.IsValid() || time.Since(n.dnsResolved) > n.dnsTTL
}

func (n *Node) GetTTL() time.Duration {
return n.dnsTTL
}

// MarshalText implements encoding.TextMarshaler.
func (n *Node) MarshalText() ([]byte, error) {
return []byte(n.String()), nil
Expand Down
52 changes: 52 additions & 0 deletions p2p/enode/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ func TestNodeEndpoints(t *testing.T) {
wantUDP int
wantTCP int
wantQUIC int
wantDNS string
}
tests := []endpointTest{
{
Expand Down Expand Up @@ -268,6 +269,54 @@ func TestNodeEndpoints(t *testing.T) {
wantIP: netip.MustParseAddr("2001::ff00:0042:8329"),
wantQUIC: 9001,
},
{
name: "dns-only",
node: func() *Node {
var r enr.Record
n := SignNull(&r, id)
n.dnsName = "example.com"
n.tcp = 30303
n.udp = 30303
return n
}(),
wantTCP: 30303,
wantUDP: 30303,
wantDNS: "example.com",
},
{
name: "dns-with-ports",
node: func() *Node {
var r enr.Record
r.Set(enr.TCP(9000))
r.Set(enr.UDP(9001))
n := SignNull(&r, id)
n.dnsName = "node.example.org"
n.tcp = 9000
n.udp = 9001
return n
}(),
wantTCP: 9000,
wantUDP: 9001,
wantDNS: "node.example.org",
},
{
name: "dns-with-ip-fallback",
node: func() *Node {
var r enr.Record
r.Set(enr.IPv4Addr(netip.MustParseAddr("192.168.1.1")))
r.Set(enr.TCP(9000))
r.Set(enr.UDP(9000))
n := SignNull(&r, id)
n.dnsName = "node.example.org"
n.tcp = 9000
n.udp = 9000
return n
}(),
wantIP: netip.MustParseAddr("192.168.1.1"),
wantTCP: 9000,
wantUDP: 9000,
wantDNS: "node.example.org",
},
}

for _, test := range tests {
Expand All @@ -284,6 +333,9 @@ func TestNodeEndpoints(t *testing.T) {
if quic, _ := test.node.QUICEndpoint(); test.wantQUIC != int(quic.Port()) {
t.Errorf("node has wrong QUIC port %d, want %d", quic.Port(), test.wantQUIC)
}
if test.wantDNS != test.node.DNSName() {
t.Errorf("node has wrong DNS name %s, want %s", test.node.DNSName(), test.wantDNS)
}
})
}
}
Expand Down
55 changes: 36 additions & 19 deletions p2p/enode/urlv4.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,20 @@ func NewV4(pubkey *ecdsa.PublicKey, ip net.IP, tcp, udp int) *Node {
return n
}

func NewV4WithDNS(pubkey *ecdsa.PublicKey, ip net.IP, dnsName string, tcp, udp int) *Node {
n := NewV4(pubkey, ip, tcp, udp)
// Always set TCP/UDP ports regardless of IP
// This is to ensure that the node is always
// considered valid even if the IP is not
// set.
if len(ip) == 0 {
n.tcp = uint16(tcp)
n.udp = uint16(udp)
}
n.dnsName = dnsName
return n
}

// isNewV4 returns true for nodes created by NewV4.
func isNewV4(n *Node) bool {
var k s256raw
Expand All @@ -126,20 +140,6 @@ func parseComplete(rawurl string) (*Node, error) {
if id, err = parsePubkey(u.User.String()); err != nil {
return nil, fmt.Errorf("invalid public key (%v)", err)
}
// Parse the IP address.
ip := net.ParseIP(u.Hostname())
if ip == nil {
ips, err := lookupIPFunc(u.Hostname())
if err != nil {
return nil, err
}
ip = ips[0]
}
// Ensure the IP is 4 bytes long for IPv4 addresses.
if ipv4 := ip.To4(); ipv4 != nil {
ip = ipv4
}
// Parse the port numbers.
if tcpPort, err = strconv.ParseUint(u.Port(), 10, 16); err != nil {
return nil, errors.New("invalid port")
}
Expand All @@ -151,6 +151,19 @@ func parseComplete(rawurl string) (*Node, error) {
return nil, errors.New("invalid discport in query")
}
}
// Check if hostname is an IP address and create node accordingly
hostname := u.Hostname()
ip := net.ParseIP(hostname)
if ip == nil {
ips, err := lookupIPFunc(hostname)
if err != nil {
return NewV4WithDNS(id, nil, hostname, int(tcpPort), int(udpPort)), nil
}
ip = ips[0]
}
if ipv4 := ip.To4(); ipv4 != nil {
ip = ipv4
}
return NewV4(id, ip, int(tcpPort), int(udpPort)), nil
}

Expand Down Expand Up @@ -181,15 +194,19 @@ func (n *Node) URLv4() string {
nodeid = fmt.Sprintf("%s.%x", scheme, n.id[:])
}
u := url.URL{Scheme: "enode"}
if !n.ip.IsValid() {
if !n.ip.IsValid() && n.dnsName == "" {
u.Host = nodeid
return u.String()
}
u.User = url.User(nodeid)
if n.dnsName != "" {
u.Host = fmt.Sprintf("%s:%d", n.dnsName, n.TCP())
} else {
addr := net.TCPAddr{IP: n.IP(), Port: n.TCP()}
u.User = url.User(nodeid)
u.Host = addr.String()
if n.UDP() != n.TCP() {
u.RawQuery = "discport=" + strconv.Itoa(n.UDP())
}
}
if n.UDP() != n.TCP() {
u.RawQuery = "discport=" + strconv.Itoa(n.UDP())
}
return u.String()
}
Expand Down
14 changes: 10 additions & 4 deletions p2p/enode/urlv4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ var parseNodeTests = []struct {
wantError: enr.ErrInvalidSig.Error(),
},
// Complete node URLs with IP address and ports
{
input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@invalid.:3",
wantError: `no such host`,
},
{
input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo",
wantError: `invalid port`,
Expand All @@ -82,6 +78,16 @@ var parseNodeTests = []struct {
input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:3?discport=foo",
wantError: `invalid discport in query`,
},
{
input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@valid.:3",
wantResult: NewV4WithDNS(
hexPubkey("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
nil,
"valid.",
3,
3,
),
},
{
input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150",
wantResult: NewV4(
Expand Down
Loading