Skip to content

Commit f75c3c9

Browse files
authored
Add CheckRedirect callback (#269)
This commit adds a CheckRedirect callback that opamp-go will call before following a redirect from the server it's trying to connect to. Like in net/http, CheckRedirect can be used to observe the request chain that the client is taking while attempting to make a connection. The user can optionally terminate redirect following by returning an error from CheckRedirect. Unlike in net/http, the via parameter for CheckRedirect is a slice of responses. Since the user would have no other way to access these in the context of opamp-go, CheckRedirect makes them available so that users can know exactly what status codes and headers are set in the response. Another small improvement is that the error callback is no longer called when redirecting. This should help to prevent undue error logging by opamp-go consumers. Since the CheckRedirect callback is now available, it also doesn't represent any loss in functionality to opamp-go consumers.
1 parent bfdb952 commit f75c3c9

File tree

7 files changed

+308
-18
lines changed

7 files changed

+308
-18
lines changed

client/httpclient_test.go

+88
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@ package client
33
import (
44
"compress/gzip"
55
"context"
6+
"errors"
67
"io"
78
"net/http"
9+
"net/http/httptest"
10+
"net/url"
811
"sync/atomic"
912
"testing"
1013
"time"
1114

1215
"github.com/stretchr/testify/assert"
16+
"github.com/stretchr/testify/mock"
1317
"google.golang.org/protobuf/proto"
1418

1519
"github.com/open-telemetry/opamp-go/client/internal"
@@ -223,3 +227,87 @@ func TestHTTPClientStartWithZeroHeartbeatInterval(t *testing.T) {
223227
// Shutdown the Server.
224228
srv.Close()
225229
}
230+
231+
func mockRedirectHTTP(t testing.TB, viaLen int, err error) *checkRedirectMock {
232+
m := &checkRedirectMock{
233+
t: t,
234+
viaLen: viaLen,
235+
http: true,
236+
}
237+
m.On("CheckRedirect", mock.Anything, mock.Anything, mock.Anything).Return(err)
238+
return m
239+
}
240+
241+
func TestRedirectHTTP(t *testing.T) {
242+
redirectee := internal.StartMockServer(t)
243+
tests := []struct {
244+
Name string
245+
Redirector *httptest.Server
246+
ExpError bool
247+
MockRedirect *checkRedirectMock
248+
}{
249+
{
250+
Name: "simple redirect",
251+
Redirector: redirectServer("http://"+redirectee.Endpoint, 302),
252+
},
253+
{
254+
Name: "check redirect",
255+
Redirector: redirectServer("http://"+redirectee.Endpoint, 302),
256+
MockRedirect: mockRedirectHTTP(t, 1, nil),
257+
},
258+
{
259+
Name: "check redirect returns error",
260+
Redirector: redirectServer("http://"+redirectee.Endpoint, 302),
261+
MockRedirect: mockRedirectHTTP(t, 1, errors.New("hello")),
262+
ExpError: true,
263+
},
264+
}
265+
266+
for _, test := range tests {
267+
t.Run(test.Name, func(t *testing.T) {
268+
var connectErr atomic.Value
269+
var connected atomic.Value
270+
271+
settings := &types.StartSettings{
272+
Callbacks: types.Callbacks{
273+
OnConnect: func(ctx context.Context) {
274+
connected.Store(1)
275+
},
276+
OnConnectFailed: func(ctx context.Context, err error) {
277+
connectErr.Store(err)
278+
},
279+
},
280+
}
281+
if test.MockRedirect != nil {
282+
settings.Callbacks = types.Callbacks{
283+
OnConnect: func(ctx context.Context) {
284+
connected.Store(1)
285+
},
286+
OnConnectFailed: func(ctx context.Context, err error) {
287+
connectErr.Store(err)
288+
},
289+
CheckRedirect: test.MockRedirect.CheckRedirect,
290+
}
291+
}
292+
reURL, _ := url.Parse(test.Redirector.URL) // err can't be non-nil
293+
settings.OpAMPServerURL = reURL.String()
294+
client := NewHTTP(nil)
295+
prepareClient(t, settings, client)
296+
297+
err := client.Start(context.Background(), *settings)
298+
if err != nil {
299+
t.Fatal(err)
300+
}
301+
defer client.Stop(context.Background())
302+
// Wait for connection to be established.
303+
eventually(t, func() bool {
304+
return connected.Load() != nil || connectErr.Load() != nil
305+
})
306+
if test.ExpError && connectErr.Load() == nil {
307+
t.Error("expected non-nil error")
308+
} else if err := connectErr.Load(); !test.ExpError && err != nil {
309+
t.Fatal(err)
310+
}
311+
})
312+
}
313+
}

client/internal/httpsender.go

+8
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,14 @@ func (h *HTTPSender) Run(
9898
h.callbacks = callbacks
9999
h.receiveProcessor = newReceivedProcessor(h.logger, callbacks, h, clientSyncedState, packagesStateProvider, capabilities, packageSyncMutex)
100100

101+
// we need to detect if the redirect was ever set, if not, we want default behaviour
102+
if callbacks.CheckRedirect != nil {
103+
h.client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
104+
// viaResp only non-nil for ws client
105+
return callbacks.CheckRedirect(req, via, nil)
106+
}
107+
}
108+
101109
for {
102110
pollingTimer := time.NewTimer(time.Millisecond * time.Duration(atomic.LoadInt64(&h.pollingIntervalMs)))
103111
select {

client/types/callbacks.go

+14
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package types
22

33
import (
44
"context"
5+
"net/http"
56

67
"github.com/open-telemetry/opamp-go/protobufs"
78
)
@@ -116,6 +117,19 @@ type Callbacks struct {
116117

117118
// OnCommand is called when the Server requests that the connected Agent perform a command.
118119
OnCommand func(ctx context.Context, command *protobufs.ServerToAgentCommand) error
120+
121+
// CheckRedirect is called before following a redirect, allowing the client
122+
// the opportunity to observe the redirect chain, and optionally terminate
123+
// following redirects early.
124+
//
125+
// CheckRedirect is intended to be similar, although not exactly equivalent,
126+
// to net/http.Client's CheckRedirect feature. Unlike in net/http, the via
127+
// parameter is a slice of HTTP responses, instead of requests. This gives
128+
// an opportunity to users to know what the exact response headers and
129+
// status were. The request itself can be obtained from the response.
130+
//
131+
// The responses in the via parameter are passed with their bodies closed.
132+
CheckRedirect func(req *http.Request, viaReq []*http.Request, via []*http.Response) error
119133
}
120134

121135
func (c *Callbacks) SetDefaults() {

client/wsclient.go

+74-14
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ type wsClient struct {
4848
// Network connection timeout used for the WebSocket closing handshake.
4949
// This field is currently only modified during testing.
5050
connShutdownTimeout time.Duration
51+
52+
// responseChain is used for the "via" argument in CheckRedirect.
53+
// It is appended to with every redirect followed, and zeroed on a succesful
54+
// connection. responseChain should only be referred to by the goroutine that
55+
// runs tryConnectOnce and its synchronous callees.
56+
responseChain []*http.Response
5157
}
5258

5359
// NewWebSocket creates a new OpAMP Client that uses WebSocket transport.
@@ -151,11 +157,77 @@ func (c *wsClient) SendCustomMessage(message *protobufs.CustomMessage) (messageS
151157
return c.common.SendCustomMessage(message)
152158
}
153159

160+
func viaReq(resps []*http.Response) []*http.Request {
161+
reqs := make([]*http.Request, 0, len(resps))
162+
for _, resp := range resps {
163+
reqs = append(reqs, resp.Request)
164+
}
165+
return reqs
166+
}
167+
168+
// handleRedirect checks a failed websocket upgrade response for a 3xx response
169+
// and a Location header. If found, it sets the URL to the location found in the
170+
// header so that it is tried on the next retry, instead of the current URL.
171+
func (c *wsClient) handleRedirect(ctx context.Context, resp *http.Response) error {
172+
// append to the responseChain so that subsequent redirects will have access
173+
c.responseChain = append(c.responseChain, resp)
174+
175+
// very liberal handling of 3xx that largely ignores HTTP semantics
176+
redirect, err := resp.Location()
177+
if err != nil {
178+
c.common.Logger.Errorf(ctx, "%d redirect, but no valid location: %s", resp.StatusCode, err)
179+
return err
180+
}
181+
182+
// It's slightly tricky to make CheckRedirect work. The WS HTTP request is
183+
// formed within the websocket library. To work around that, copy the
184+
// previous request, available in the response, and set the URL to the new
185+
// location. It should then result in the same URL that the websocket
186+
// library will form.
187+
nextRequest := resp.Request.Clone(ctx)
188+
nextRequest.URL = redirect
189+
190+
// if CheckRedirect results in an error, it gets returned, terminating
191+
// redirection. As with stdlib, the error is wrapped in url.Error.
192+
if c.common.Callbacks.CheckRedirect != nil {
193+
if err := c.common.Callbacks.CheckRedirect(nextRequest, viaReq(c.responseChain), c.responseChain); err != nil {
194+
return &url.Error{
195+
Op: "Get",
196+
URL: nextRequest.URL.String(),
197+
Err: err,
198+
}
199+
}
200+
}
201+
202+
// rewrite the scheme for the sake of tolerance
203+
if redirect.Scheme == "http" {
204+
redirect.Scheme = "ws"
205+
} else if redirect.Scheme == "https" {
206+
redirect.Scheme = "wss"
207+
}
208+
c.common.Logger.Debugf(ctx, "%d redirect to %s", resp.StatusCode, redirect)
209+
210+
// Set the URL to the redirect, so that it connects to it on the
211+
// next cycle.
212+
c.url = redirect
213+
214+
return nil
215+
}
216+
154217
// Try to connect once. Returns an error if connection fails and optional retryAfter
155218
// duration to indicate to the caller to retry after the specified time as instructed
156219
// by the Server.
157220
func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinternal.OptionalDuration, err error) {
158221
var resp *http.Response
222+
var redirecting bool
223+
defer func() {
224+
if err != nil && !redirecting {
225+
c.responseChain = nil
226+
if !c.common.IsStopping() {
227+
c.common.Callbacks.OnConnectFailed(ctx, err)
228+
}
229+
}
230+
}()
159231
conn, resp, err := c.dialer.DialContext(ctx, c.url.String(), c.getHeader())
160232
if err != nil {
161233
if !c.common.IsStopping() {
@@ -164,22 +236,10 @@ func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinterna
164236
if resp != nil {
165237
duration := sharedinternal.ExtractRetryAfterHeader(resp)
166238
if resp.StatusCode >= 300 && resp.StatusCode < 400 {
167-
// very liberal handling of 3xx that largely ignores HTTP semantics
168-
redirect, err := resp.Location()
169-
if err != nil {
170-
c.common.Logger.Errorf(ctx, "%d redirect, but no valid location: %s", resp.StatusCode, err)
239+
redirecting = true
240+
if err := c.handleRedirect(ctx, resp); err != nil {
171241
return duration, err
172242
}
173-
// rewrite the scheme for the sake of tolerance
174-
if redirect.Scheme == "http" {
175-
redirect.Scheme = "ws"
176-
} else if redirect.Scheme == "https" {
177-
redirect.Scheme = "wss"
178-
}
179-
c.common.Logger.Debugf(ctx, "%d redirect to %s", resp.StatusCode, redirect)
180-
// Set the URL to the redirect, so that it connects to it on the
181-
// next cycle.
182-
c.url = redirect
183243
} else {
184244
c.common.Logger.Errorf(ctx, "Server responded with status=%v", resp.Status)
185245
}

0 commit comments

Comments
 (0)