@@ -68,6 +68,7 @@ func New(logger types.Logger) *server {
68
68
69
69
func (s * server ) Attach (settings Settings ) (HTTPHandlerFunc , ConnContext , error ) {
70
70
s .settings = settings
71
+ s .settings .Callbacks .SetDefaults ()
71
72
s .wsUpgrader = websocket.Upgrader {
72
73
EnableCompression : settings .EnableCompression ,
73
74
}
@@ -169,26 +170,25 @@ func (s *server) Addr() net.Addr {
169
170
170
171
func (s * server ) httpHandler (w http.ResponseWriter , req * http.Request ) {
171
172
var connectionCallbacks serverTypes.ConnectionCallbacks
172
- if s .settings .Callbacks != nil {
173
- resp := s .settings .Callbacks .OnConnecting (req )
174
- if ! resp .Accept {
175
- // HTTP connection is not accepted. Set the response headers.
176
- for k , v := range resp .HTTPResponseHeader {
177
- w .Header ().Set (k , v )
178
- }
179
- // And write the response status code.
180
- w .WriteHeader (resp .HTTPStatusCode )
181
- return
173
+ resp := s .settings .Callbacks .OnConnecting (req )
174
+ if ! resp .Accept {
175
+ // HTTP connection is not accepted. Set the response headers.
176
+ for k , v := range resp .HTTPResponseHeader {
177
+ w .Header ().Set (k , v )
182
178
}
183
- // use connection-specific handler provided by ConnectionResponse
184
- connectionCallbacks = resp .ConnectionCallbacks
179
+ // And write the response status code.
180
+ w .WriteHeader (resp .HTTPStatusCode )
181
+ return
185
182
}
183
+ // use connection-specific handler provided by ConnectionResponse
184
+ connectionCallbacks = resp .ConnectionCallbacks
185
+ connectionCallbacks .SetDefaults ()
186
186
187
187
// HTTP connection is accepted. Check if it is a plain HTTP request.
188
188
189
189
if req .Header .Get (headerContentType ) == contentTypeProtobuf {
190
190
// Yes, a plain HTTP request.
191
- s .handlePlainHTTPRequest (req , w , connectionCallbacks )
191
+ s .handlePlainHTTPRequest (req , w , & connectionCallbacks )
192
192
return
193
193
}
194
194
@@ -201,10 +201,10 @@ func (s *server) httpHandler(w http.ResponseWriter, req *http.Request) {
201
201
202
202
// Return from this func to reduce memory usage.
203
203
// Handle the connection on a separate goroutine.
204
- go s .handleWSConnection (req .Context (), conn , connectionCallbacks )
204
+ go s .handleWSConnection (req .Context (), conn , & connectionCallbacks )
205
205
}
206
206
207
- func (s * server ) handleWSConnection (reqCtx context.Context , wsConn * websocket.Conn , connectionCallbacks serverTypes.ConnectionCallbacks ) {
207
+ func (s * server ) handleWSConnection (reqCtx context.Context , wsConn * websocket.Conn , connectionCallbacks * serverTypes.ConnectionCallbacks ) {
208
208
agentConn := wsConnection {wsConn : wsConn , connMutex : & sync.Mutex {}}
209
209
210
210
defer func () {
@@ -216,14 +216,10 @@ func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Co
216
216
}
217
217
}()
218
218
219
- if connectionCallbacks != nil {
220
- connectionCallbacks .OnConnectionClose (agentConn )
221
- }
219
+ connectionCallbacks .OnConnectionClose (agentConn )
222
220
}()
223
221
224
- if connectionCallbacks != nil {
225
- connectionCallbacks .OnConnected (reqCtx , agentConn )
226
- }
222
+ connectionCallbacks .OnConnected (reqCtx , agentConn )
227
223
228
224
sentCustomCapabilities := false
229
225
@@ -254,21 +250,19 @@ func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Co
254
250
continue
255
251
}
256
252
257
- if connectionCallbacks != nil {
258
- response := connectionCallbacks .OnMessage (msgContext , agentConn , & request )
259
- if len (response .InstanceUid ) == 0 {
260
- response .InstanceUid = request .InstanceUid
261
- }
262
- if ! sentCustomCapabilities {
263
- response .CustomCapabilities = & protobufs.CustomCapabilities {
264
- Capabilities : s .settings .CustomCapabilities ,
265
- }
266
- sentCustomCapabilities = true
267
- }
268
- err = agentConn .Send (msgContext , response )
269
- if err != nil {
270
- s .logger .Errorf (msgContext , "Cannot send message to WebSocket: %v" , err )
253
+ response := connectionCallbacks .OnMessage (msgContext , agentConn , & request )
254
+ if len (response .InstanceUid ) == 0 {
255
+ response .InstanceUid = request .InstanceUid
256
+ }
257
+ if ! sentCustomCapabilities {
258
+ response .CustomCapabilities = & protobufs.CustomCapabilities {
259
+ Capabilities : s .settings .CustomCapabilities ,
271
260
}
261
+ sentCustomCapabilities = true
262
+ }
263
+ err = agentConn .Send (msgContext , response )
264
+ if err != nil {
265
+ s .logger .Errorf (msgContext , "Cannot send message to WebSocket: %v" , err )
272
266
}
273
267
}
274
268
}
@@ -310,7 +304,7 @@ func compressGzip(data []byte) ([]byte, error) {
310
304
return buf .Bytes (), nil
311
305
}
312
306
313
- func (s * server ) handlePlainHTTPRequest (req * http.Request , w http.ResponseWriter , connectionCallbacks serverTypes.ConnectionCallbacks ) {
307
+ func (s * server ) handlePlainHTTPRequest (req * http.Request , w http.ResponseWriter , connectionCallbacks * serverTypes.ConnectionCallbacks ) {
314
308
bodyBytes , err := s .readReqBody (req )
315
309
if err != nil {
316
310
s .logger .Debugf (req .Context (), "Cannot read HTTP body: %v" , err )
@@ -331,11 +325,6 @@ func (s *server) handlePlainHTTPRequest(req *http.Request, w http.ResponseWriter
331
325
conn : connFromRequest (req ),
332
326
}
333
327
334
- if connectionCallbacks == nil {
335
- w .WriteHeader (http .StatusInternalServerError )
336
- return
337
- }
338
-
339
328
connectionCallbacks .OnConnected (req .Context (), agentConn )
340
329
341
330
defer func () {
0 commit comments