Skip to content

Commit 8e27201

Browse files
committed
Using url.Parse to improve address parsing
1 parent 2f2d29f commit 8e27201

File tree

4 files changed

+60
-117
lines changed

4 files changed

+60
-117
lines changed

core/outbound/clients/resolver/address.go

+40-79
Original file line numberDiff line numberDiff line change
@@ -42,76 +42,61 @@ func ToNetwork(protocol string) string {
4242
}
4343
}
4444

45-
// ExtractSocksAddress parse socks5 address,
46-
// support two formats: socks5://127.0.0.1:1080 and 127.0.0.1:1080
47-
func ExtractSocksAddress(rawAddress string) (string, error) {
45+
// support two formats: scheme://127.0.0.1:1080 or 127.0.0.1:1080
46+
func extractUrl(rawAddress string, protocol string) (host string, port string, err error) {
47+
48+
if !strings.Contains(rawAddress, "://") {
49+
rawAddress = protocol + "://" + rawAddress
50+
}
51+
4852
uri, err := url.Parse(rawAddress)
4953
if err != nil {
50-
// socks5 address format is 127.0.0.1:1080
51-
_, _, err = net.SplitHostPort(rawAddress)
52-
isJustIP := isJustIP(rawAddress)
53-
if err != nil && !isJustIP {
54-
log.Warnf("socks5 address %s is invalid", rawAddress)
55-
return "", errors.New("socks5 address is invalid")
56-
}
57-
if isJustIP {
58-
rawAddress = rawAddress + ":" + getDefaultPort("socks5")
59-
}
60-
return rawAddress, nil
54+
log.Warnf("url %s is invalid", rawAddress)
55+
return "", "", errors.New("url is invalid")
6156
}
62-
// socks5://127.0.0.1:1080
63-
if len(uri.Scheme) == 0 || uri.Scheme != "socks5" {
64-
return "", errors.New("socks5 address is invalid")
57+
host = uri.Hostname()
58+
59+
if len(uri.Scheme) == 0 || uri.Scheme != protocol {
60+
return "", "", errors.New("url is invalid")
6561
}
66-
port := uri.Port()
62+
63+
port = uri.Port()
6764
if len(port) == 0 {
68-
port = "1080"
65+
port = getDefaultPort(protocol)
6966
}
70-
address := net.JoinHostPort(uri.Hostname(), port)
71-
return address, nil
67+
return
7268
}
7369

74-
// ExtractTLSDNSAddress parse tcp-tls format: dns.google:853@8.8.8.8
75-
func ExtractTLSDNSAddress(rawAddress string) (host string, port string, ip string, err error) {
70+
func ExtractFullUrl(rawAddress string, protocol string) (string, error) {
71+
host, port, err := extractUrl(rawAddress, protocol)
72+
return net.JoinHostPort(host, port), err
73+
}
74+
75+
func extractTLSDNSAddress(rawAddress string, protocol string) (host string, port string, err error) {
76+
rawAddress = protocol + "://" + rawAddress
7677
s := strings.Split(rawAddress, "@")
77-
host, port, err = net.SplitHostPort(s[0])
78-
isJustHost := len(rawAddress) > 0
79-
if err != nil && !isJustHost {
80-
log.Warnf("dns server address %s is invalid", rawAddress)
81-
return "", "", "", errors.New("dns up server address is invalid")
82-
}
83-
if err != nil && isJustHost {
84-
host = s[0]
85-
if isJustIP(host) {
86-
host = generateLiteralIPv6AddressIfNecessary(host)
87-
}
88-
port = getDefaultPort("tcp-tls")
89-
}
9078

91-
ip = s[1]
92-
if isJustIP(ip) {
93-
ip = generateLiteralIPv6AddressIfNecessary(ip)
94-
} else {
95-
log.Warnf("dns server address %s is invalid", rawAddress)
96-
return "", "", "", errors.New("dns up server address is invalid")
79+
host, port, err = extractUrl(s[0], protocol)
80+
81+
if err != nil {
82+
return "", "", nil
9783
}
98-
return host, port, ip, nil
99-
}
10084

101-
// extractNormalDNSAddress parse normal format: 8.8.8.8:53
102-
func extractNormalDNSAddress(rawAddress string, protocol string) (host string, port string, err error) {
103-
host, port, err = net.SplitHostPort(rawAddress)
104-
isJustIP := isJustIP(rawAddress)
105-
if err != nil && !isJustIP {
85+
if len(s) == 2 && isJustIP(s[1]) {
86+
host = generateLiteralIPv6AddressIfNecessary(s[1])
87+
} else {
10688
log.Warnf("dns server address %s is invalid", rawAddress)
10789
return "", "", errors.New("dns up server address is invalid")
10890
}
109-
if isJustIP {
110-
host = generateLiteralIPv6AddressIfNecessary(rawAddress)
111-
port = getDefaultPort(protocol)
112-
}
11391
return host, port, nil
92+
}
11493

94+
func ExtractTLSDNSHostName(rawAddress string) (host string, err error) {
95+
rawAddress = "tcp-tls" + "://" + rawAddress
96+
s := strings.Split(rawAddress, "@")
97+
98+
host, _, err = extractUrl(s[0], "tcp-tls")
99+
return host, err
115100
}
116101

117102
func isJustIP(rawAddress string) bool {
@@ -128,37 +113,13 @@ func generateLiteralIPv6AddressIfNecessary(rawAddress string) string {
128113
return rawAddress
129114
}
130115

131-
// extractHTTPSAddress parse https format: https://dns.google/dns-query
132-
func extractHTTPSAddress(rawAddress string) (host string, port string, err error) {
133-
uri, err := url.Parse(rawAddress)
134-
if err != nil {
135-
return "", "", err
136-
}
137-
host = uri.Hostname()
138-
port = uri.Port()
139-
if len(port) == 0 {
140-
port = getDefaultPort("https")
141-
}
142-
return host, port, nil
143-
144-
}
145-
146116
// ExtractDNSAddress parse all format, return literal IPv6 address
147117
func ExtractDNSAddress(rawAddress string, protocol string) (host string, port string, err error) {
148118
switch protocol {
149-
case "https":
150-
host, port, err = extractHTTPSAddress(rawAddress)
151119
case "tcp-tls":
152-
_host, _port, _ip, _err := ExtractTLSDNSAddress(rawAddress)
153-
if len(_ip) > 0 {
154-
host = _ip
155-
} else {
156-
host = _host
157-
}
158-
port = _port
159-
err = _err
120+
host, port, err = extractTLSDNSAddress(rawAddress, protocol)
160121
default:
161-
host, port, err = extractNormalDNSAddress(rawAddress, protocol)
122+
host, port, err = extractUrl(rawAddress, protocol)
162123
}
163124
return host, port, err
164125
}

core/outbound/clients/resolver/address_test.go

+18-36
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ func TestExtractDNSAddress(t *testing.T) {
2525
{ipv4Address, "udp", ipv4Address, "53", nil},
2626
{ipv6Address, "udp", literalIpa6Address, "53", nil},
2727
{"https://dns.google/dns-query", "https", "dns.google", "443", nil},
28+
{"dns.google/dns-query", "https", "dns.google", "443", nil},
2829
{"https://dns.google:888/dns-query", "https", "dns.google", "888", nil},
2930
}
3031
for _, tt := range tests {
@@ -37,46 +38,27 @@ func TestExtractDNSAddress(t *testing.T) {
3738
}
3839
}
3940

40-
func TestExtractSocksAddress(t *testing.T) {
41+
func TestExtractFullUrl(t *testing.T) {
4142
var tests = []struct {
42-
in string
43-
out string
43+
url string
44+
protocol string
45+
out string
4446
}{
45-
{"socks5://" + ipv4Address + ":80", ipv4Address + ":80"},
46-
{"socks5://" + ipv6Address + ":80", ipv6Address + ":80"},
47-
{"socks5://" + ipv6Address, ipv6Address + ":1080"},
48-
{"" + ipv4Address + ":80", ipv4Address + ":80"},
49-
{"" + ipv6Address + ":80", ipv6Address + ":80"},
50-
{"" + ipv6Address, ipv6Address + ":1080"},
47+
{"socks5://" + ipv4Address + ":80", "socks5", ipv4Address + ":80"},
48+
{ipv4Address + ":80", "socks5", ipv4Address + ":80"},
49+
{ipv6Address + ":80", "socks5", ipv6Address + ":80"},
50+
{ipv6Address, "socks5", ipv6Address + ":1080"},
51+
{ipv6Address, "https", ipv6Address + ":443"},
52+
{"tcp-tls://" + ipv6Address, "tcp-tls", ipv6Address + ":853"},
53+
{"" + ipv4Address + ":80", "socks5", ipv4Address + ":80"},
54+
{"" + ipv6Address + ":80", "socks5", ipv6Address + ":80"},
55+
{"" + ipv6Address, "socks5", ipv6Address + ":1080"},
56+
{"abc.com", "socks5", "abc.com:1080"},
5157
}
5258
for _, tt := range tests {
53-
t.Run(tt.in, func(t *testing.T) {
54-
addr, err := ExtractSocksAddress(tt.in)
55-
testEqual(t, addr, tt.out)
56-
testErr(t, err)
57-
})
58-
}
59-
}
60-
61-
func TestExtractTLSDNSAddress(t *testing.T) {
62-
63-
var tests = []struct {
64-
in string
65-
host string
66-
port string
67-
ip string
68-
err error
69-
}{
70-
{"dns.google:853@" + ipv6Address, "dns.google", "853", literalIpa6Address, nil},
71-
{"dns.google@" + ipv6Address, "dns.google", "853", literalIpa6Address, nil},
72-
{"dns.google:853@" + ipv4Address, "dns.google", "853", ipv4Address, nil},
73-
}
74-
for _, tt := range tests {
75-
t.Run(tt.in, func(t *testing.T) {
76-
host, port, ip, err := ExtractTLSDNSAddress(tt.in)
77-
testEqual(t, host, tt.host)
78-
testEqual(t, port, tt.port)
79-
testEqual(t, ip, tt.ip)
59+
t.Run(tt.url, func(t *testing.T) {
60+
url, err := ExtractFullUrl(tt.url, tt.protocol)
61+
testEqual(t, url, tt.out)
8062
testErr(t, err)
8163
})
8264
}

core/outbound/clients/resolver/base_resolver.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ func (r *BaseResolver) CreateBaseConn() (net.Conn, error) {
9494
dialer := net.Dialer{Timeout: r.getDialTimeout()}
9595
dialerFunc := dialer.Dial
9696
if r.dnsUpstream.SOCKS5Address != "" {
97-
socksAddress, err := ExtractSocksAddress(r.dnsUpstream.SOCKS5Address)
97+
socksAddress, err := ExtractFullUrl(r.dnsUpstream.SOCKS5Address, "socks5")
9898
if err != nil {
9999
return nil, err
100100
}

core/outbound/clients/resolver/tcptls_resolver.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func (r *TCPTLSResolver) createTlsConn() (conn net.Conn, err error) {
3333
if err != nil {
3434
return nil, err
3535
}
36-
host, _, _, err := ExtractTLSDNSAddress(r.dnsUpstream.Address)
36+
host, err := ExtractTLSDNSHostName(r.dnsUpstream.Address)
3737
if err != nil {
3838
return nil, err
3939
}

0 commit comments

Comments
 (0)