@@ -48,6 +48,12 @@ type wsClient struct {
48
48
// Network connection timeout used for the WebSocket closing handshake.
49
49
// This field is currently only modified during testing.
50
50
connShutdownTimeout time.Duration
51
+
52
+ // responseChain is used for the "via" argument in CheckRedirect.
53
+ // It is appended to with every redirect followed, and zeroed on a succesful
54
+ // connection. responseChain should only be referred to by the goroutine that
55
+ // runs tryConnectOnce and its synchronous callees.
56
+ responseChain []* http.Response
51
57
}
52
58
53
59
// NewWebSocket creates a new OpAMP Client that uses WebSocket transport.
@@ -151,11 +157,77 @@ func (c *wsClient) SendCustomMessage(message *protobufs.CustomMessage) (messageS
151
157
return c .common .SendCustomMessage (message )
152
158
}
153
159
160
+ func viaReq (resps []* http.Response ) []* http.Request {
161
+ reqs := make ([]* http.Request , 0 , len (resps ))
162
+ for _ , resp := range resps {
163
+ reqs = append (reqs , resp .Request )
164
+ }
165
+ return reqs
166
+ }
167
+
168
+ // handleRedirect checks a failed websocket upgrade response for a 3xx response
169
+ // and a Location header. If found, it sets the URL to the location found in the
170
+ // header so that it is tried on the next retry, instead of the current URL.
171
+ func (c * wsClient ) handleRedirect (ctx context.Context , resp * http.Response ) error {
172
+ // append to the responseChain so that subsequent redirects will have access
173
+ c .responseChain = append (c .responseChain , resp )
174
+
175
+ // very liberal handling of 3xx that largely ignores HTTP semantics
176
+ redirect , err := resp .Location ()
177
+ if err != nil {
178
+ c .common .Logger .Errorf (ctx , "%d redirect, but no valid location: %s" , resp .StatusCode , err )
179
+ return err
180
+ }
181
+
182
+ // It's slightly tricky to make CheckRedirect work. The WS HTTP request is
183
+ // formed within the websocket library. To work around that, copy the
184
+ // previous request, available in the response, and set the URL to the new
185
+ // location. It should then result in the same URL that the websocket
186
+ // library will form.
187
+ nextRequest := resp .Request .Clone (ctx )
188
+ nextRequest .URL = redirect
189
+
190
+ // if CheckRedirect results in an error, it gets returned, terminating
191
+ // redirection. As with stdlib, the error is wrapped in url.Error.
192
+ if c .common .Callbacks .CheckRedirect != nil {
193
+ if err := c .common .Callbacks .CheckRedirect (nextRequest , viaReq (c .responseChain ), c .responseChain ); err != nil {
194
+ return & url.Error {
195
+ Op : "Get" ,
196
+ URL : nextRequest .URL .String (),
197
+ Err : err ,
198
+ }
199
+ }
200
+ }
201
+
202
+ // rewrite the scheme for the sake of tolerance
203
+ if redirect .Scheme == "http" {
204
+ redirect .Scheme = "ws"
205
+ } else if redirect .Scheme == "https" {
206
+ redirect .Scheme = "wss"
207
+ }
208
+ c .common .Logger .Debugf (ctx , "%d redirect to %s" , resp .StatusCode , redirect )
209
+
210
+ // Set the URL to the redirect, so that it connects to it on the
211
+ // next cycle.
212
+ c .url = redirect
213
+
214
+ return nil
215
+ }
216
+
154
217
// Try to connect once. Returns an error if connection fails and optional retryAfter
155
218
// duration to indicate to the caller to retry after the specified time as instructed
156
219
// by the Server.
157
220
func (c * wsClient ) tryConnectOnce (ctx context.Context ) (retryAfter sharedinternal.OptionalDuration , err error ) {
158
221
var resp * http.Response
222
+ var redirecting bool
223
+ defer func () {
224
+ if err != nil && ! redirecting {
225
+ c .responseChain = nil
226
+ if ! c .common .IsStopping () {
227
+ c .common .Callbacks .OnConnectFailed (ctx , err )
228
+ }
229
+ }
230
+ }()
159
231
conn , resp , err := c .dialer .DialContext (ctx , c .url .String (), c .getHeader ())
160
232
if err != nil {
161
233
if ! c .common .IsStopping () {
@@ -164,22 +236,10 @@ func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinterna
164
236
if resp != nil {
165
237
duration := sharedinternal .ExtractRetryAfterHeader (resp )
166
238
if resp .StatusCode >= 300 && resp .StatusCode < 400 {
167
- // very liberal handling of 3xx that largely ignores HTTP semantics
168
- redirect , err := resp .Location ()
169
- if err != nil {
170
- c .common .Logger .Errorf (ctx , "%d redirect, but no valid location: %s" , resp .StatusCode , err )
239
+ redirecting = true
240
+ if err := c .handleRedirect (ctx , resp ); err != nil {
171
241
return duration , err
172
242
}
173
- // rewrite the scheme for the sake of tolerance
174
- if redirect .Scheme == "http" {
175
- redirect .Scheme = "ws"
176
- } else if redirect .Scheme == "https" {
177
- redirect .Scheme = "wss"
178
- }
179
- c .common .Logger .Debugf (ctx , "%d redirect to %s" , resp .StatusCode , redirect )
180
- // Set the URL to the redirect, so that it connects to it on the
181
- // next cycle.
182
- c .url = redirect
183
243
} else {
184
244
c .common .Logger .Errorf (ctx , "Server responded with status=%v" , resp .Status )
185
245
}
0 commit comments