From a1ecc998aa96eaf32404204be4d36f8f9f01d26a Mon Sep 17 00:00:00 2001 From: arrlancore Date: Sun, 6 Oct 2024 12:55:14 +0700 Subject: [PATCH] feat: add request and response interceptors to client requests --- axios4go_test.go | 56 ++++++++++++++++++++++++++++++++++++++++++++++ client.go | 58 +++++++++++++++++++++++++++++++++++------------- 2 files changed, 99 insertions(+), 15 deletions(-) diff --git a/axios4go_test.go b/axios4go_test.go index 480ce99..2299629 100644 --- a/axios4go_test.go +++ b/axios4go_test.go @@ -698,3 +698,59 @@ func TestValidateStatus(t *testing.T) { }) } + +func TestInterceptors(t *testing.T) { + server := setupTestServer() + defer server.Close() + + var interceptedRequest *http.Request + requestInterceptorCalled := false + requestInterceptor := func(req *http.Request) error { + req.Header.Set("X-Intercepted", "true") + interceptedRequest = req + requestInterceptorCalled = true + return nil + } + + responseInterceptor := func(resp *http.Response) error { + resp.Header.Set("X-Intercepted-Response", "true") + return nil + } + + opts := &RequestOptions{ + headers: map[string]string{ + "Content-Type": "application/json", + }, + params: map[string]string{ + "query": "myQuery", + }, + } + + opts.interceptorOptions = InterceptorOptions{ + requestInterceptors: []func(*http.Request) error{requestInterceptor}, + responseInterceptors: []func(*http.Response) error{responseInterceptor}, + } + + t.Run("Interceptors Test", func(t *testing.T) { + response, err := Get(server.URL+"/get", opts) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if !requestInterceptorCalled { + t.Error("Request interceptor was not called") + } + + if interceptedRequest != nil { + if interceptedRequest.Header.Get("X-Intercepted") != "true" { + t.Errorf("Expected request header 'X-Intercepted' to be 'true', got '%s'", interceptedRequest.Header.Get("X-Intercepted")) + } + } else { + t.Error("Intercepted request is nil") + } + + if response.Headers.Get("X-Intercepted-Response") != "true" { + t.Errorf("Expected response header 'X-Intercepted-Response' to be 'true', got '%s'", response.Headers.Get("X-Intercepted-Response")) + } + }) +} diff --git a/client.go b/client.go index 9adbfb9..c7379cc 100644 --- a/client.go +++ b/client.go @@ -35,22 +35,30 @@ type Promise struct { mu sync.Mutex } +type RequestInterceptors []func(*http.Request) error +type ResponseInterceptors []func(*http.Response) error +type InterceptorOptions struct { + requestInterceptors RequestInterceptors + responseInterceptors ResponseInterceptors +} + type RequestOptions struct { - method string - url string - baseURL string - params map[string]string - body interface{} - headers map[string]string - timeout int - auth *auth - responseType string - responseEncoding string - maxRedirects int - maxContentLength int - maxBodyLength int - decompress bool - validateStatus func(int) bool + method string + url string + baseURL string + params map[string]string + body interface{} + headers map[string]string + timeout int + auth *auth + responseType string + responseEncoding string + maxRedirects int + maxContentLength int + maxBodyLength int + decompress bool + validateStatus func(int) bool + interceptorOptions InterceptorOptions } type auth struct { @@ -336,6 +344,13 @@ func (c *Client) Request(options *RequestOptions) (*Response, error) { return nil, err } + for _, interceptor := range options.interceptorOptions.requestInterceptors { + err = interceptor(req) + if err != nil { + return nil, fmt.Errorf("request interceptor failed: %w", err) + } + } + if options.headers == nil { options.headers = make(map[string]string) } @@ -394,6 +409,13 @@ func (c *Client) Request(options *RequestOptions) (*Response, error) { return nil, fmt.Errorf("Request failed with status code: %v", resp.StatusCode) } + for _, interceptor := range options.interceptorOptions.responseInterceptors { + err = interceptor(resp) + if err != nil { + return nil, fmt.Errorf("response interceptor failed: %w", err) + } + } + return &Response{ StatusCode: resp.StatusCode, Headers: resp.Header, @@ -444,6 +466,12 @@ func mergeOptions(dst, src *RequestOptions) { if src.validateStatus != nil { dst.validateStatus = src.validateStatus } + if src.interceptorOptions.requestInterceptors != nil { + dst.interceptorOptions.requestInterceptors = src.interceptorOptions.requestInterceptors + } + if src.interceptorOptions.responseInterceptors != nil { + dst.interceptorOptions.responseInterceptors = src.interceptorOptions.responseInterceptors + } dst.decompress = src.decompress }