Skip to content

Commit 601b46e

Browse files
authored
修改1.timeout,2. 修改测试代码 (#28)
1 parent e2ce486 commit 601b46e

9 files changed

+90
-46
lines changed

client.go

+7-20
Original file line numberDiff line numberDiff line change
@@ -196,34 +196,32 @@ func (d *DialOption) Dial() (wsCon *Conn, err error) {
196196
}
197197

198198
var conn net.Conn
199-
begin := time.Now()
200199

201200
hostName := hostname.GetHostName(d.u)
202-
// conn, err := net.DialTimeout("tcp", d.u.Host /* TODO 加端号*/, d.dialTimeout)
203-
dialFunc := net.Dial
201+
dialFunc := net.DialTimeout
204202
if d.dialFunc != nil {
205203
dialInterface, err := d.dialFunc()
206204
if err != nil {
207205
return nil, err
208206
}
209-
dialFunc = dialInterface.Dial
207+
dialFunc = func(network, address string, timeout time.Duration) (net.Conn, error) {
208+
return dialInterface.Dial(network, address)
209+
}
210210
}
211211

212212
if d.proxyFunc != nil {
213213
proxyURL, err := d.proxyFunc(req)
214214
if err != nil {
215215
return nil, err
216216
}
217-
dialFunc = newhttpProxy(proxyURL, dialFunc).Dial
217+
dialFunc = newhttpProxy(proxyURL, dialFunc).DialTimeout
218218
}
219219

220-
conn, err = dialFunc("tcp", hostName)
220+
conn, err = dialFunc("tcp", hostName, d.dialTimeout)
221221
if err != nil {
222222
return nil, err
223223
}
224224

225-
dialDuration := time.Since(begin)
226-
227225
conn = d.tlsConn(conn)
228226
defer func() {
229227
if err != nil && conn != nil {
@@ -232,18 +230,7 @@ func (d *DialOption) Dial() (wsCon *Conn, err error) {
232230
}
233231
}()
234232

235-
if to := d.dialTimeout - dialDuration; to > 0 {
236-
if err = conn.SetDeadline(time.Now().Add(to)); err != nil {
237-
return
238-
}
239-
}
240-
241-
defer func() {
242-
if err == nil {
243-
err = conn.SetDeadline(time.Time{})
244-
}
245-
}()
246-
233+
err = conn.SetDeadline(time.Time{})
247234
if err = req.Write(conn); err != nil {
248235
return
249236
}

common_options_test.go

+48-2
Original file line numberDiff line numberDiff line change
@@ -2227,13 +2227,14 @@ func Test_CommonOption(t *testing.T) {
22272227
t.Run("22.3.WithClientReadMaxMessage", func(t *testing.T) {
22282228
var tsort testServerOptionReadTimeout
22292229

2230-
upgrade := NewUpgrade(WithServerCallback(&tsort), WithServerReadTimeout(time.Millisecond*60))
2230+
upgrade := NewUpgrade()
22312231
tsort.err = make(chan error, 1)
22322232
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
22332233
c, err := upgrade.Upgrade(w, r)
22342234
if err != nil {
22352235
t.Error(err)
22362236
}
2237+
time.Sleep(time.Second / 100)
22372238
err = c.WriteMessage(Binary, bytes.Repeat([]byte("1"), 1025))
22382239
if err != nil {
22392240
t.Error(err)
@@ -2245,12 +2246,57 @@ func Test_CommonOption(t *testing.T) {
22452246
defer ts.Close()
22462247

22472248
url := strings.ReplaceAll(ts.URL, "http", "ws")
2248-
con, err := Dial(url, WithClientBufioParseMode(), WithClientReadMaxMessage(1<<10), WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) {
2249+
con, err := Dial(url, WithClientCallback(&tsort), WithClientBufioParseMode(), WithClientReadMaxMessage(1<<10))
2250+
if err != nil {
2251+
t.Error(err)
2252+
return
2253+
}
2254+
defer con.Close()
2255+
go func() {
2256+
_ = con.ReadLoop()
2257+
}()
2258+
2259+
select {
2260+
case d := <-tsort.err:
2261+
if d == nil {
2262+
t.Errorf("got:nil, need:error\n")
2263+
}
2264+
case <-time.After(100 * time.Hour):
2265+
t.Errorf(" Test_ServerOption:WithServerReadMaxMessage timeout\n")
2266+
}
2267+
if atomic.LoadInt32(&tsort.run) != 1 {
2268+
t.Error("not run server:method fail")
2269+
}
2270+
})
2271+
t.Run("22.4.WithClientReadMaxMessage-ParseWindows", func(t *testing.T) {
2272+
var tsort testServerOptionReadTimeout
2273+
2274+
upgrade := NewUpgrade()
2275+
tsort.err = make(chan error, 1)
2276+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2277+
c, err := upgrade.Upgrade(w, r)
2278+
if err != nil {
2279+
t.Error(err)
2280+
}
2281+
time.Sleep(time.Second / 100)
2282+
err = c.WriteMessage(Binary, bytes.Repeat([]byte("1"), 1025))
2283+
if err != nil {
2284+
t.Error(err)
2285+
return
2286+
}
2287+
c.StartReadLoop()
22492288
}))
2289+
2290+
defer ts.Close()
2291+
2292+
url := strings.ReplaceAll(ts.URL, "http", "ws")
2293+
con, err := Dial(url, WithClientCallback(&tsort), WithClientReadMaxMessage(1<<10))
22502294
if err != nil {
22512295
t.Error(err)
2296+
return
22522297
}
22532298
defer con.Close()
2299+
con.StartReadLoop()
22542300

22552301
select {
22562302
case d := <-tsort.err:

config.go

+6
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,16 @@ import (
2727

2828
var ErrDialFuncAndProxyFunc = errors.New("dialFunc and proxyFunc can't be set at the same time")
2929

30+
// 握手
3031
type Dialer interface {
3132
Dial(network, addr string) (c net.Conn, err error)
3233
}
3334

35+
// 带超时时间的握手
36+
type DialerTimeout interface {
37+
DialTimeout(network, addr string, timeout time.Duration) (c net.Conn, err error)
38+
}
39+
3440
// Config的配置,有两个种用法
3541
// 一种是声明一个全局的配置,后面不停使用。
3642
// 另外一种是局部声明一个配置,然后使用WithXXX函数设置配置

conn.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -208,10 +208,11 @@ func (c *Conn) readDataFromNet(headArray *[enum.MaxFrameHeaderSize]byte, bufioPa
208208
}
209209
} else {
210210
r := io.Reader(c.br)
211+
var lr io.Reader
211212
if c.readMaxMessage > 0 {
212-
r = limitreader.NewLimitReader(c.br, c.readMaxMessage)
213+
lr = limitreader.NewLimitReader(c.br, c.readMaxMessage)
213214
}
214-
f, err = frame.ReadFrameFromReaderV2(r, headArray, bufioPayload)
215+
f, err = frame.ReadFrameFromReaderV3(r, lr, headArray, bufioPayload)
215216
}
216217
if err != nil {
217218
c.writeAndMaybeOnClose(err)

conn_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ func TestFragmentFrame(t *testing.T) {
608608
select {
609609
case <-data:
610610
atomic.AddInt32(&run, 1)
611-
case <-time.After(500 * time.Hour):
611+
case <-time.After(500 * time.Millisecond):
612612
}
613613

614614
if atomic.LoadInt32(&run) != 1 {

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module github.com/antlabs/quickws
33
go 1.21
44

55
require (
6-
github.com/antlabs/wsutil v0.1.10
6+
github.com/antlabs/wsutil v0.1.11
77
golang.org/x/net v0.23.0
88
)
99

go.sum

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
github.com/antlabs/wsutil v0.1.10 h1:86p67dG8/iiQ+yZrHVl73OPHGnXfXopFSU0w84fLOdE=
2-
github.com/antlabs/wsutil v0.1.10/go.mod h1:Pk7xYOw3o5iEB6ukiOu+2uJMLYeMVVjJLazFD3okI2A=
1+
github.com/antlabs/wsutil v0.1.11 h1:bIVZ3Hxdq5ByZKu5OXL/cMtanEw6YlxdtUDiySI77Q0=
2+
github.com/antlabs/wsutil v0.1.11/go.mod h1:Pk7xYOw3o5iEB6ukiOu+2uJMLYeMVVjJLazFD3okI2A=
33
github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU=
44
github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
55
golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=

proxy.go

+10-8
Original file line numberDiff line numberDiff line change
@@ -20,31 +20,33 @@ import (
2020
"net"
2121
"net/http"
2222
"net/url"
23+
"time"
2324

2425
"github.com/antlabs/wsutil/hostname"
2526
)
2627

2728
type (
28-
dialFunc func(network, addr string) (c net.Conn, err error)
29+
dialFunc func(network, addr string, timeout time.Duration) (c net.Conn, err error)
2930
httpProxy struct {
30-
proxyAddr *url.URL
31-
dial func(network, addr string) (c net.Conn, err error)
31+
proxyAddr *url.URL
32+
dialTimeout func(network, addr string, timeout time.Duration) (c net.Conn, err error)
33+
timeout time.Duration
3234
}
3335
)
3436

35-
var _ Dialer = (*httpProxy)(nil)
37+
var _ DialerTimeout = (*httpProxy)(nil)
3638

3739
func newhttpProxy(u *url.URL, dial dialFunc) *httpProxy {
38-
return &httpProxy{proxyAddr: u, dial: dial}
40+
return &httpProxy{proxyAddr: u, dialTimeout: dial}
3941
}
4042

41-
func (h *httpProxy) Dial(network, addr string) (c net.Conn, err error) {
43+
func (h *httpProxy) DialTimeout(network, addr string, timeout time.Duration) (c net.Conn, err error) {
4244
if h.proxyAddr == nil {
43-
return h.dial(network, addr)
45+
return h.dialTimeout(network, addr, h.timeout)
4446
}
4547

4648
hostName := hostname.GetHostName(h.proxyAddr)
47-
c, err = h.dial(network, hostName)
49+
c, err = h.dialTimeout(network, hostName, h.timeout)
4850
if err != nil {
4951
return nil, err
5052
}

proxy_test.go

+12-10
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"net/url"
2323
"strings"
2424
"testing"
25+
"time"
2526
)
2627

2728
type testServer struct {
@@ -133,7 +134,7 @@ func Test_Proxy(t *testing.T) {
133134
func Test_httpProxy_Dial(t *testing.T) {
134135
type fields struct {
135136
proxyAddr *url.URL
136-
dial func(network, addr string) (c net.Conn, err error)
137+
dial func(network, addr string, timeout time.Duration) (c net.Conn, err error)
137138
}
138139
type args struct {
139140
network string
@@ -146,12 +147,12 @@ func Test_httpProxy_Dial(t *testing.T) {
146147
wantC net.Conn
147148
wantErr bool
148149
}{
149-
// TODO: Add test cases.
150+
// 0
150151
{
151152
name: "No proxy address",
152153
fields: fields{
153154
proxyAddr: nil,
154-
dial: func(network, addr string) (c net.Conn, err error) {
155+
dial: func(network, addr string, timeout time.Duration) (c net.Conn, err error) {
155156
// Simulate successful dialing
156157
return &net.TCPConn{}, errors.New("fail")
157158
},
@@ -163,11 +164,12 @@ func Test_httpProxy_Dial(t *testing.T) {
163164
wantC: &net.TCPConn{},
164165
wantErr: true,
165166
},
167+
// 1
166168
{
167169
name: "Proxy address",
168170
fields: fields{
169171
proxyAddr: &url.URL{Host: "1.2.3:8080", User: url.UserPassword("user", "password")},
170-
dial: func(network, addr string) (c net.Conn, err error) {
172+
dial: func(network, addr string, timeout time.Duration) (c net.Conn, err error) {
171173
// Simulate successful dialing
172174
return &net.TCPConn{}, errors.New("fail")
173175
},
@@ -179,11 +181,12 @@ func Test_httpProxy_Dial(t *testing.T) {
179181
wantC: &net.TCPConn{},
180182
wantErr: true,
181183
},
184+
// 2
182185
{
183186
name: "Proxy address",
184187
fields: fields{
185188
proxyAddr: &url.URL{Host: "1.2.3:8080", User: url.UserPassword("user", "password")},
186-
dial: func(network, addr string) (c net.Conn, err error) {
189+
dial: func(network, addr string, timeout time.Duration) (c net.Conn, err error) {
187190
// Simulate successful dialing
188191
return &net.TCPConn{}, nil
189192
},
@@ -193,17 +196,16 @@ func Test_httpProxy_Dial(t *testing.T) {
193196
addr: "a.b.c:80",
194197
},
195198
wantC: &net.TCPConn{},
196-
wantErr: true,
199+
wantErr: false,
197200
},
198201
}
199202
for i, tt := range tests {
200203
t.Run(tt.name, func(t *testing.T) {
201204
h := &httpProxy{
202-
proxyAddr: tt.fields.proxyAddr,
203-
dial: tt.fields.dial,
205+
proxyAddr: tt.fields.proxyAddr,
206+
dialTimeout: tt.fields.dial,
204207
}
205-
_, err := h.Dial(tt.args.network, tt.args.addr)
206-
// gotC, err := h.Dial(tt.args.network, tt.args.addr)
208+
_, err := h.dialTimeout(tt.args.network, tt.args.addr, 0)
207209
if (err != nil) != tt.wantErr {
208210
t.Errorf("index:%d, httpProxy.Dial() error = %v, wantErr %v", i, err, tt.wantErr)
209211
return

0 commit comments

Comments
 (0)