Skip to content

Commit fe887d8

Browse files
authored
add datasouces access in auth middleware (gofr-dev#738)
1 parent 8d09bc1 commit fe887d8

File tree

7 files changed

+163
-45
lines changed

7 files changed

+163
-45
lines changed

docs/advanced-guide/http-authentication/page.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ Use `EnableBasicAuthWithFunc(validationFunc)` to implement your own validation l
4343
The `validationFunc` takes the username and password as arguments and returns true if valid, false otherwise.
4444

4545
```go
46-
func validateUser(username string, password string) bool {
46+
func validateUser(c *container.Container, username, password string) bool {
4747
// Implement your credential validation logic here
4848
// This example uses hardcoded credentials for illustration only
4949
return username == "john" && password == "doe123"
@@ -52,7 +52,7 @@ func validateUser(username string, password string) bool {
5252
func main() {
5353
app := gofr.New()
5454

55-
app.EnableBasicAuthWithFunc(validateUser)
55+
app.EnableBasicAuthWithValidator(validateUser)
5656

5757
app.GET("/secure-data", func(c *gofr.Context) (interface{}, error) {
5858
// Handle access to secure data
@@ -102,7 +102,7 @@ func main() {
102102
```go
103103
package main
104104

105-
func apiKeyValidator(apiKey string) bool {
105+
func apiKeyValidator(c *container.Container, apiKey string) bool {
106106
validKeys := []string{"f0e1dffd-0ff0-4ac8-92a3-22d44a1464e4", "d7e4b46e-5b04-47b2-836c-2c7c91250f40"}
107107

108108
return slices.Contains(validKeys, apiKey)
@@ -112,7 +112,7 @@ func main() {
112112
// initialise gofr object
113113
app := gofr.New()
114114

115-
app.EnableAPIKeyAuthWithFunc(apiKeyValidator)
115+
app.EnableAPIKeyAuthWithValidator(apiKeyValidator)
116116

117117
app.GET("/customer", Customer)
118118

pkg/gofr/gofr.go

+23-4
Original file line numberDiff line numberDiff line change
@@ -347,16 +347,35 @@ func (a *App) EnableBasicAuth(credentials ...string) {
347347
a.httpServer.router.Use(middleware.BasicAuthMiddleware(middleware.BasicAuthProvider{Users: users}))
348348
}
349349

350+
// Deprecated: EnableBasicAuthWithFunc is deprecated and will be removed in future releases, users must use
351+
// EnableBasicAuthWithValidator as it has access to application datasources.
350352
func (a *App) EnableBasicAuthWithFunc(validateFunc func(username, password string) bool) {
351-
a.httpServer.router.Use(middleware.BasicAuthMiddleware(middleware.BasicAuthProvider{ValidateFunc: validateFunc}))
353+
a.httpServer.router.Use(middleware.BasicAuthMiddleware(middleware.BasicAuthProvider{ValidateFunc: validateFunc, Container: a.container}))
354+
}
355+
356+
func (a *App) EnableBasicAuthWithValidator(validateFunc func(c *container.Container, username, password string) bool) {
357+
a.httpServer.router.Use(middleware.BasicAuthMiddleware(middleware.BasicAuthProvider{
358+
ValidateFuncWithDatasources: validateFunc, Container: a.container}))
352359
}
353360

354361
func (a *App) EnableAPIKeyAuth(apiKeys ...string) {
355-
a.httpServer.router.Use(middleware.APIKeyAuthMiddleware(nil, apiKeys...))
362+
a.httpServer.router.Use(middleware.APIKeyAuthMiddleware(middleware.APIKeyAuthProvider{}, apiKeys...))
363+
}
364+
365+
// Deprecated: EnableAPIKeyAuthWithFunc is deprecated and will be removed in future releases, users must use
366+
// EnableAPIKeyAuthWithValidator as it has access to application datasources.
367+
func (a *App) EnableAPIKeyAuthWithFunc(validateFunc func(apiKey string) bool) {
368+
a.httpServer.router.Use(middleware.APIKeyAuthMiddleware(middleware.APIKeyAuthProvider{
369+
ValidateFunc: validateFunc,
370+
Container: a.container,
371+
}))
356372
}
357373

358-
func (a *App) EnableAPIKeyAuthWithFunc(validator func(apiKey string) bool) {
359-
a.httpServer.router.Use(middleware.APIKeyAuthMiddleware(validator))
374+
func (a *App) EnableAPIKeyAuthWithValidator(validateFunc func(c *container.Container, apiKey string) bool) {
375+
a.httpServer.router.Use(middleware.APIKeyAuthMiddleware(middleware.APIKeyAuthProvider{
376+
ValidateFuncWithDatasources: validateFunc,
377+
Container: a.container,
378+
}))
360379
}
361380

362381
func (a *App) EnableOAuth(jwksEndpoint string, refreshInterval int) {

pkg/gofr/gofr_test.go

+48
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,54 @@ func Test_UseMiddleware(t *testing.T) {
434434
assert.Equal(t, "applied", testHeaderValue, "Test_UseMiddleware Failed! header value mismatch.")
435435
}
436436

437+
func Test_APIKeyAuthMiddleware(t *testing.T) {
438+
c, _ := container.NewMockContainer(t)
439+
440+
app := &App{
441+
httpServer: &httpServer{
442+
router: gofrHTTP.NewRouter(),
443+
port: 8001,
444+
},
445+
container: c,
446+
Config: config.NewMockConfig(map[string]string{"REQUEST_TIMEOUT": "5"}),
447+
}
448+
449+
apiKeys := []string{"test-key"}
450+
validateFunc := func(_ *container.Container, apiKey string) bool {
451+
return apiKey == "test-key"
452+
}
453+
454+
// Registering APIKey middleware with and without custom validator
455+
app.EnableAPIKeyAuth(apiKeys...)
456+
app.EnableAPIKeyAuthWithValidator(validateFunc)
457+
458+
app.GET("/test", func(_ *Context) (interface{}, error) {
459+
return "success", nil
460+
})
461+
462+
go app.Run()
463+
time.Sleep(1 * time.Second)
464+
465+
var netClient = &http.Client{
466+
Timeout: time.Second * 10,
467+
}
468+
469+
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet,
470+
"http://localhost:8001/test", http.NoBody)
471+
req.Header.Set("X-API-Key", "test-key")
472+
473+
// Send the request and check for successful response
474+
resp, err := netClient.Do(req)
475+
if err != nil {
476+
t.Errorf("error while making HTTP request in Test_APIKeyAuthMiddleware. err: %v", err)
477+
return
478+
}
479+
480+
defer resp.Body.Close()
481+
482+
assert.Equal(t, http.StatusOK, resp.StatusCode, "Test_APIKeyAuthMiddleware Failed!")
483+
}
484+
437485
func Test_SwaggerEndpoints(t *testing.T) {
438486
// Create the openapi.json file within the static directory
439487
openAPIFilePath := filepath.Join("static", OpenAPIJSON)

pkg/gofr/http/middleware/apikey_auth.go

+22-11
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,20 @@ package middleware
44

55
import (
66
"net/http"
7+
8+
"gofr.dev/pkg/gofr/container"
79
)
810

11+
// APIKeyAuthProvider represents a basic authentication provider.
12+
type APIKeyAuthProvider struct {
13+
ValidateFunc func(apiKey string) bool
14+
ValidateFuncWithDatasources func(c *container.Container, apiKey string) bool
15+
Container *container.Container
16+
}
17+
918
// APIKeyAuthMiddleware creates a middleware function that enforces API key authentication based on the provided API
1019
// keys or a validation function.
11-
func APIKeyAuthMiddleware(validator func(apiKey string) bool, apiKeys ...string) func(handler http.Handler) http.Handler {
20+
func APIKeyAuthMiddleware(a APIKeyAuthProvider, apiKeys ...string) func(handler http.Handler) http.Handler {
1221
return func(handler http.Handler) http.Handler {
1322
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1423
if isWellKnown(r.URL.Path) {
@@ -22,7 +31,7 @@ func APIKeyAuthMiddleware(validator func(apiKey string) bool, apiKeys ...string)
2231
return
2332
}
2433

25-
if !validateKey(validator, authKey, apiKeys...) {
34+
if !validateKey(a, authKey, apiKeys...) {
2635
http.Error(w, "Unauthorized: Invalid Authorization header", http.StatusUnauthorized)
2736
return
2837
}
@@ -42,15 +51,17 @@ func isPresent(authKey string, apiKeys ...string) bool {
4251
return false
4352
}
4453

45-
func validateKey(validator func(apiKey string) bool, authKey string, apiKeys ...string) bool {
46-
if validator != nil {
47-
if !validator(authKey) {
48-
return false
49-
}
50-
} else {
51-
if !isPresent(authKey, apiKeys...) {
52-
return false
53-
}
54+
func validateKey(provider APIKeyAuthProvider, authKey string, apiKeys ...string) bool {
55+
if provider.ValidateFunc != nil && !provider.ValidateFunc(authKey) {
56+
return false
57+
}
58+
59+
if provider.ValidateFuncWithDatasources != nil && !provider.ValidateFuncWithDatasources(provider.Container, authKey) {
60+
return false
61+
}
62+
63+
if provider.ValidateFunc == nil && provider.ValidateFuncWithDatasources == nil {
64+
return isPresent(authKey, apiKeys...)
5465
}
5566

5667
return true

pkg/gofr/http/middleware/apikey_auth_test.go

+34-14
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ import (
77
"testing"
88

99
"github.com/stretchr/testify/assert"
10+
11+
"gofr.dev/pkg/gofr/container"
12+
)
13+
14+
const (
15+
validKey1 string = "valid-key-1"
16+
validKey2 string = "valid-key-2"
1017
)
1118

1219
func Test_ApiKeyAuthMiddleware(t *testing.T) {
@@ -15,7 +22,10 @@ func Test_ApiKeyAuthMiddleware(t *testing.T) {
1522
})
1623

1724
validator := func(apiKey string) bool {
18-
return apiKey == "valid-key"
25+
return apiKey == validKey1
26+
}
27+
validatorWithDB := func(_ *container.Container, apiKey string) bool {
28+
return apiKey == validKey2
1929
}
2030

2131
req, err := http.NewRequestWithContext(context.Background(), "GET", "/", http.NoBody)
@@ -24,26 +34,34 @@ func Test_ApiKeyAuthMiddleware(t *testing.T) {
2434
}
2535

2636
testCases := []struct {
27-
desc string
28-
validator func(apiKey string) bool
29-
apiKey string
30-
responseCode int
31-
responseBody string
37+
desc string
38+
validatorFunc func(akiKey string) bool
39+
validatorFuncWithDB func(c *container.Container, apiKey string) bool
40+
apiKey string
41+
responseCode int
42+
responseBody string
3243
}{
33-
{"missing api-key", nil, "", 401, "Unauthorized: Authorization header missing\n"},
34-
{"invalid api-key", nil, "invalid-key", 401, "Unauthorized: Invalid Authorization header\n"},
35-
{"valid api-key", nil, "valid-key-1", 200, "Success"},
36-
{"another valid api-key", nil, "valid-key-2", 200, "Success"},
37-
{"custom validator valid key", validator, "valid-key", 200, "Success"},
38-
{"custom validator in-valid key", validator, "invalid-key", 401, "Unauthorized: Invalid Authorization header\n"},
44+
{"missing api-key", nil, nil, "", 401, "Unauthorized: Authorization header missing\n"},
45+
{"invalid api-key", nil, nil, "invalid-key", 401, "Unauthorized: Invalid Authorization header\n"},
46+
{"valid api-key", nil, nil, validKey1, 200, "Success"},
47+
{"another valid api-key", nil, nil, validKey2, 200, "Success"},
48+
{"custom validatorFunc valid key", validator, nil, validKey1, 200, "Success"},
49+
{"custom validatorFuncWithDB valid key", nil, validatorWithDB, validKey2, 200, "Success"},
50+
{"custom validatorFuncWithDB in-valid key", nil, validatorWithDB, "invalid-key", 401, "Unauthorized: Invalid Authorization header\n"},
3951
}
4052

4153
for i, tc := range testCases {
4254
rr := httptest.NewRecorder()
4355

4456
req.Header.Set("X-API-KEY", tc.apiKey)
4557

46-
wrappedHandler := APIKeyAuthMiddleware(tc.validator, "valid-key-1", "valid-key-2")(testHandler)
58+
provider := APIKeyAuthProvider{
59+
ValidateFunc: tc.validatorFunc,
60+
ValidateFuncWithDatasources: tc.validatorFuncWithDB,
61+
Container: nil,
62+
}
63+
64+
wrappedHandler := APIKeyAuthMiddleware(provider, validKey1, validKey2)(testHandler)
4765
wrappedHandler.ServeHTTP(rr, req)
4866

4967
assert.Equal(t, tc.responseCode, rr.Code, "TEST[%d], Failed.\n%s", i, tc.desc)
@@ -60,7 +78,9 @@ func Test_ApiKeyAuthMiddleware_well_known(t *testing.T) {
6078
req := httptest.NewRequest(http.MethodGet, "/.well-known/health-check", http.NoBody)
6179
rr := httptest.NewRecorder()
6280

63-
wrappedHandler := APIKeyAuthMiddleware(nil)(testHandler)
81+
provider := APIKeyAuthProvider{}
82+
83+
wrappedHandler := APIKeyAuthMiddleware(provider)(testHandler)
6484
wrappedHandler.ServeHTTP(rr, req)
6585

6686
assert.Equal(t, 200, rr.Code, "TEST Failed.\n")

pkg/gofr/http/middleware/basic_auth.go

+16-11
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,18 @@ import (
44
"encoding/base64"
55
"net/http"
66
"strings"
7+
8+
"gofr.dev/pkg/gofr/container"
79
)
810

911
const credentialLength = 2
1012

1113
// BasicAuthProvider represents a basic authentication provider.
1214
type BasicAuthProvider struct {
13-
Users map[string]string
14-
ValidateFunc func(username, password string) bool
15+
Users map[string]string
16+
ValidateFunc func(username, password string) bool
17+
ValidateFuncWithDatasources func(c *container.Container, username, password string) bool
18+
Container *container.Container
1519
}
1620

1721
// BasicAuthMiddleware creates a middleware function that enforces basic authentication using the provided BasicAuthProvider.
@@ -58,15 +62,16 @@ func BasicAuthMiddleware(basicAuthProvider BasicAuthProvider) func(handler http.
5862
}
5963

6064
func validateCredentials(provider BasicAuthProvider, credentials []string) bool {
61-
if provider.ValidateFunc != nil {
62-
if !provider.ValidateFunc(credentials[0], credentials[1]) {
63-
return false
64-
}
65-
} else {
66-
if storedPass, ok := provider.Users[credentials[0]]; !ok || storedPass != credentials[1] {
67-
return false
68-
}
65+
if provider.ValidateFunc != nil && !provider.ValidateFunc(credentials[0], credentials[1]) {
66+
return false
67+
}
68+
69+
if provider.ValidateFuncWithDatasources != nil && !provider.ValidateFuncWithDatasources(provider.Container,
70+
credentials[0], credentials[1]) {
71+
return false
6972
}
7073

71-
return true
74+
storedPass, ok := provider.Users[credentials[0]]
75+
76+
return ok && storedPass == credentials[1]
7277
}

pkg/gofr/http/middleware/basic_auth_test.go

+16-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"testing"
77

88
"github.com/stretchr/testify/assert"
9+
"gofr.dev/pkg/gofr/container"
910
)
1011

1112
func TestBasicAuthMiddleware(t *testing.T) {
@@ -17,6 +18,14 @@ func TestBasicAuthMiddleware(t *testing.T) {
1718
return false
1819
}
1920

21+
validationFuncWithDB := func(_ *container.Container, user, pass string) bool {
22+
if user == "abc" && pass == "pass123" {
23+
return true
24+
}
25+
26+
return false
27+
}
28+
2029
testCases := []struct {
2130
name string
2231
authHeader string
@@ -32,7 +41,7 @@ func TestBasicAuthMiddleware(t *testing.T) {
3241
{
3342
name: "Valid Authorization with validation Func",
3443
authHeader: "Basic YWJjOnBhc3MxMjM=",
35-
authProvider: BasicAuthProvider{ValidateFunc: validationFunc},
44+
authProvider: BasicAuthProvider{Users: map[string]string{"abc": "pass123"}, ValidateFunc: validationFunc},
3645
expectedStatusCode: http.StatusOK,
3746
},
3847
{
@@ -41,6 +50,12 @@ func TestBasicAuthMiddleware(t *testing.T) {
4150
authProvider: BasicAuthProvider{ValidateFunc: validationFunc},
4251
expectedStatusCode: http.StatusUnauthorized,
4352
},
53+
{
54+
name: "false from validation Func with DB",
55+
authHeader: "Basic dXNlcjpwYXNzd29yZA==",
56+
authProvider: BasicAuthProvider{ValidateFuncWithDatasources: validationFuncWithDB},
57+
expectedStatusCode: http.StatusUnauthorized,
58+
},
4459
{
4560
name: "No Authorization Header",
4661
authHeader: "",

0 commit comments

Comments
 (0)