Skip to content

add UpgradeV2 #29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion autobahn/config/fuzzingclient.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
{
"outdir": "./report/",
"servers": [
{
"agent": "global",
"url": "ws://localhost:9001/global",
"options": {
"version": 18
}
},
{
"agent": "no-context-takeover-decompression-and-compression-no-tls",
"url": "ws://localhost:9001/no-context-takeover-decompression-and-compression",
Expand Down Expand Up @@ -37,4 +44,4 @@
""
],
"exclude-agent-cases": {}
}
}
19 changes: 19 additions & 0 deletions autobahn/server/autobahn-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,24 @@ func echoReadTime(w http.ResponseWriter, r *http.Request) {
_ = c.ReadLoop()
}

var upgrade = quickws.NewUpgrade(
quickws.WithServerReplyPing(),
quickws.WithServerDecompression(),
quickws.WithServerIgnorePong(),
quickws.WithServerEnableUTF8Check(),
quickws.WithServerReadTimeout(5*time.Second),
)

func global(w http.ResponseWriter, r *http.Request) {
c, err := upgrade.UpgradeV2(w, r, &echoHandler{openWriteTimeout: true})
if err != nil {
fmt.Println("Upgrade fail:", err)
return
}

_ = c.ReadLoop()
}

func startTLSServer(mux *http.ServeMux) {

cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
Expand Down Expand Up @@ -167,6 +185,7 @@ func startServer(mux *http.ServeMux) {
func main() {
mux := &http.ServeMux{}
mux.HandleFunc("/timeout", echoReadTime)
mux.HandleFunc("/global", global)
mux.HandleFunc("/no-context-takeover-decompression", echoNoContextDecompression)
mux.HandleFunc("/no-context-takeover-decompression-and-compression", echoNoContextDecompressionAndCompression)
mux.HandleFunc("/context-takeover-decompression", echoContextTakeoverDecompression)
Expand Down
5 changes: 4 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,10 @@ func (d *DialOption) Dial() (wsCon *Conn, err error) {
if err := conn.SetDeadline(time.Time{}); err != nil {
return nil, err
}
wsCon = newConn(conn, true /* client is true*/, &d.Config, fr, br)
if wsCon, err = newConn(conn, true /* client is true*/, &d.Config, fr, br); err != nil {
return nil, err
}
wsCon.pd = pd
wsCon.Callback = d.cb
return wsCon, nil
}
16 changes: 8 additions & 8 deletions common_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
// 0. CallbackFunc
func WithClientCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) ClientOption {
return func(o *DialOption) {
o.Callback = &funcToCallback{
o.cb = &funcToCallback{
onOpen: open,
onMessage: m,
onClose: c,
Expand All @@ -34,7 +34,7 @@ func WithClientCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) Cli
// 配置服务端回调函数
func WithServerCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) ServerOption {
return func(o *ConnOption) {
o.Callback = &funcToCallback{
o.cb = &funcToCallback{
onOpen: open,
onMessage: m,
onClose: c,
Expand All @@ -46,14 +46,14 @@ func WithServerCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) Ser
// 配置客户端callback
func WithClientCallback(cb Callback) ClientOption {
return func(o *DialOption) {
o.Callback = cb
o.cb = cb
}
}

// 配置服务端回调函数
func WithServerCallback(cb Callback) ServerOption {
return func(o *ConnOption) {
o.Callback = cb
o.cb = cb
}
}

Expand Down Expand Up @@ -90,14 +90,14 @@ func WithClientEnableUTF8Check() ClientOption {
// 仅仅配置OnMessae函数
func WithServerOnMessageFunc(cb OnMessageFunc) ServerOption {
return func(o *ConnOption) {
o.Callback = OnMessageFunc(cb)
o.cb = OnMessageFunc(cb)
}
}

// 仅仅配置OnMessae函数
func WithClientOnMessageFunc(cb OnMessageFunc) ClientOption {
return func(o *DialOption) {
o.Callback = OnMessageFunc(cb)
o.cb = OnMessageFunc(cb)
}
}

Expand Down Expand Up @@ -292,14 +292,14 @@ func WithClientReadTimeout(t time.Duration) ClientOption {
// 17.1 配置服务端OnClose
func WithServerOnCloseFunc(onClose func(c *Conn, err error)) ServerOption {
return func(o *ConnOption) {
o.Callback = OnCloseFunc(onClose)
o.cb = OnCloseFunc(onClose)
}
}

// 17.2 配置客户端OnClose
func WithClientOnCloseFunc(onClose func(c *Conn, err error)) ClientOption {
return func(o *DialOption) {
o.Callback = OnCloseFunc(onClose)
o.cb = OnCloseFunc(onClose)
}
}

Expand Down
4 changes: 2 additions & 2 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type DialerTimeout interface {
// 一种是声明一个全局的配置,后面不停使用。
// 另外一种是局部声明一个配置,然后使用WithXXX函数设置配置
type Config struct {
Callback
cb Callback
deflate.PermessageDeflateConf // 静态配置, 从WithXXX函数中获取
tcpNoDelay bool
replyPing bool // 开启自动回复
Expand All @@ -67,7 +67,7 @@ func (c *Config) initPayloadSize() int {

// 默认设置
func (c *Config) defaultSetting() error {
c.Callback = &DefCallback{}
c.cb = &DefCallback{}
c.maxDelayWriteNum = 10
c.windowsMultipleTimesPayloadSize = 1.0
c.delayWriteInitBufferSize = 8 * 1024
Expand Down
2 changes: 1 addition & 1 deletion config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestConfig_defaultSetting(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Config{
Callback: tt.fields.Callback,
cb: tt.fields.Callback,
tcpNoDelay: tt.fields.tcpNoDelay,
replyPing: tt.fields.replyPing,
ignorePong: tt.fields.ignorePong,
Expand Down
11 changes: 7 additions & 4 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type delayWrite struct {
type Conn struct {
fr fixedreader.FixedReader // 默认使用windows
c net.Conn // net.Conn
Callback // callback移至conn中
br *bufio.Reader // read和fr同时只能使用一个
*Config // config 可能是全局,也可能是局部初始化得来的
pd deflate.PermessageDeflateConf // permessageDeflate局部配置
Expand Down Expand Up @@ -87,18 +88,20 @@ func setNoDelay(c net.Conn, noDelay bool) error {
return nil
}

func newConn(c net.Conn, client bool, conf *Config, fr fixedreader.FixedReader, br *bufio.Reader) *Conn {
_ = setNoDelay(c, conf.tcpNoDelay)
func newConn(c net.Conn, client bool, conf *Config, fr fixedreader.FixedReader, br *bufio.Reader) (wsCon *Conn, err error) {
if err = setNoDelay(c, conf.tcpNoDelay); err != nil {
return nil, err
}

con := &Conn{
wsCon = &Conn{
c: c,
client: client,
Config: conf,
fr: fr,
br: br,
}

return con
return wsCon, err
}

// 返回标准库的net.Conn
Expand Down
18 changes: 14 additions & 4 deletions upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ func NewUpgrade(opts ...ServerOption) *UpgradeServer {
}

func (u *UpgradeServer) Upgrade(w http.ResponseWriter, r *http.Request) (c *Conn, err error) {
return upgradeInner(w, r, &u.config)
return upgradeInner(w, r, &u.config, nil)
}

func (u *UpgradeServer) UpgradeV2(w http.ResponseWriter, r *http.Request, cb Callback) (c *Conn, err error) {
return upgradeInner(w, r, &u.config, cb)
}

func Upgrade(w http.ResponseWriter, r *http.Request, opts ...ServerOption) (c *Conn, err error) {
Expand All @@ -54,10 +58,10 @@ func Upgrade(w http.ResponseWriter, r *http.Request, opts ...ServerOption) (c *C
for _, o := range opts {
o(&conf)
}
return upgradeInner(w, r, &conf.Config)
return upgradeInner(w, r, &conf.Config, nil)
}

func upgradeInner(w http.ResponseWriter, r *http.Request, conf *Config) (c *Conn, err error) {
func upgradeInner(w http.ResponseWriter, r *http.Request, conf *Config, cb Callback) (wsCon *Conn, err error) {
if ecode, err := checkRequest(r); err != nil {
http.Error(w, err.Error(), ecode)
return nil, err
Expand Down Expand Up @@ -125,9 +129,15 @@ func upgradeInner(w http.ResponseWriter, r *http.Request, conf *Config) (c *Conn
if err := conn.SetDeadline(time.Time{}); err != nil {
return nil, err
}
wsCon := newConn(conn, false, conf, fr, br)
if wsCon, err = newConn(conn, false, conf, fr, br); err != nil {
return nil, err
}

wsCon.pd = pd
wsCon.Callback = cb
if cb == nil {
wsCon.Callback = conf.cb
}
return wsCon, nil
}

Expand Down
Loading