From e7ed5adaaa06e6d54175ef5335e93a43e51fad3d Mon Sep 17 00:00:00 2001 From: Mak Muftic Date: Fri, 18 Sep 2020 09:24:40 +0200 Subject: [PATCH 01/10] Implement auth middleware and ping endpoint --- internal/auth/auth.go | 16 ++++++++++----- internal/auth/middleware.go | 37 ++++++++++++++++++++++++++++++++++ internal/controllers/api.go | 4 +++- internal/controllers/ping.go | 18 +++++++++++++++++ internal/models/node.go | 2 +- internal/models/ping.go | 14 +++++++++++++ internal/repositories/node.go | 2 +- internal/repositories/ping.go | 32 +++++++++++++++++++++++++++++ internal/router/router.go | 20 +++++++++++++++--- mocks/models/NodeRepository.go | 6 +++--- 10 files changed, 137 insertions(+), 14 deletions(-) create mode 100644 internal/auth/middleware.go create mode 100644 internal/controllers/ping.go create mode 100644 internal/models/ping.go create mode 100644 internal/repositories/ping.go diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 166cf904..a7feed74 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -25,11 +25,17 @@ func SetAuthSecret(secret string) error { } } +type CustomClaims struct { + Authorized bool `json:"authorized"` + NodeId string `json:"node_id"` + jwt.StandardClaims +} + func CreateNewToken(nodeId string) (string, error) { - // set claims - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "authorized": true, - "node_id": nodeId, - }) + claims := CustomClaims{ + Authorized: true, + NodeId: nodeId, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) return token.SignedString([]byte(authSecret)) } diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go new file mode 100644 index 00000000..4b1c588f --- /dev/null +++ b/internal/auth/middleware.go @@ -0,0 +1,37 @@ +package auth + +import ( + "context" + "fmt" + "github.com/dgrijalva/jwt-go" + "net/http" +) + +func AuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + jwtToken := r.Header.Get("X-Auth-Token") + + token, err := jwt.ParseWithClaims(jwtToken, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(authSecret), nil + }) + + if token == nil { + // todo + return + } + + if claims, ok := token.Claims.(*CustomClaims); ok && token.Valid { + ctx := context.WithValue(r.Context(), "node-id", claims.NodeId) + // Access context values in handlers like this + // props, _ := r.Context().Value("props").(jwt.MapClaims) + next.ServeHTTP(w, r.WithContext(ctx)) + } else { + fmt.Println(err) + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("Unauthorized")) + } + }) +} \ No newline at end of file diff --git a/internal/controllers/api.go b/internal/controllers/api.go index cea0746c..61d6aff1 100644 --- a/internal/controllers/api.go +++ b/internal/controllers/api.go @@ -6,10 +6,12 @@ import ( type ApiController struct { nodeRepo models.NodeRepository + pingRepo models.PingRepository } -func NewApiController(nodeRepo models.NodeRepository) *ApiController { +func NewApiController(nodeRepo models.NodeRepository, pingRepo models.PingRepository) *ApiController { return &ApiController{ nodeRepo: nodeRepo, + pingRepo: pingRepo, } } diff --git a/internal/controllers/ping.go b/internal/controllers/ping.go new file mode 100644 index 00000000..282e231f --- /dev/null +++ b/internal/controllers/ping.go @@ -0,0 +1,18 @@ +package controllers + +import ( + "github.com/NodeFactoryIo/vedran/internal/models" + "net/http" + "time" +) + +func (c ApiController) PingHandler(w http.ResponseWriter, r *http.Request) { + id := r.Context().Value("node-id").(string) + err := c.pingRepo.Save(&models.Ping{ + NodeId: id, + Timestamp: time.Now(), + }) + if err != nil { + // todo handle + } +} \ No newline at end of file diff --git a/internal/models/node.go b/internal/models/node.go index 9c799276..b9505e9d 100644 --- a/internal/models/node.go +++ b/internal/models/node.go @@ -9,7 +9,7 @@ type Node struct { } type NodeRepository interface { - FindByID(ID int) (*Node, error) + FindByID(ID string) (*Node, error) Save(node *Node) error GetAll() (*[]Node, error) } diff --git a/internal/models/ping.go b/internal/models/ping.go new file mode 100644 index 00000000..3dc0b6af --- /dev/null +++ b/internal/models/ping.go @@ -0,0 +1,14 @@ +package models + +import "time" + +type Ping struct { + NodeId string `storm:"id"` + Timestamp time.Time +} + +type PingRepository interface { + FindByNodeID(nodeId string) (*Ping, error) + Save(ping *Ping) error + GetAll() (*[]Ping, error) +} diff --git a/internal/repositories/node.go b/internal/repositories/node.go index c51e8645..41c4d319 100644 --- a/internal/repositories/node.go +++ b/internal/repositories/node.go @@ -15,7 +15,7 @@ func NewNodeRepo(db *storm.DB) *NodeRepo { } } -func (r *NodeRepo) FindByID(ID int) (*models.Node, error) { +func (r *NodeRepo) FindByID(ID string) (*models.Node, error) { var node *models.Node err := r.db.One("ID", ID, node) return node, err diff --git a/internal/repositories/ping.go b/internal/repositories/ping.go new file mode 100644 index 00000000..7281af1c --- /dev/null +++ b/internal/repositories/ping.go @@ -0,0 +1,32 @@ +package repositories + +import ( + "github.com/NodeFactoryIo/vedran/internal/models" + "github.com/asdine/storm/v3" +) + +type PingRepo struct { + db *storm.DB +} + +func NewPingRepo(db *storm.DB) *PingRepo { + return &PingRepo{ + db: db, + } +} + +func (r *PingRepo) FindByNodeID(nodeId string) (*models.Ping, error) { + var ping *models.Ping + err := r.db.One("NodeId", nodeId, ping) + return ping, err +} + +func (r *PingRepo) Save(ping *models.Ping) error { + return r.db.Save(ping) +} + +func (r PingRepo) GetAll() (*[]models.Ping, error) { + var pings []models.Ping + err := r.db.All(&pings) + return &pings, err +} \ No newline at end of file diff --git a/internal/router/router.go b/internal/router/router.go index 8823f460..406f7359 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -1,20 +1,34 @@ package router import ( + "github.com/NodeFactoryIo/vedran/internal/auth" "github.com/NodeFactoryIo/vedran/internal/controllers" "github.com/NodeFactoryIo/vedran/internal/repositories" "github.com/asdine/storm/v3" "github.com/gorilla/mux" + "net/http" ) func CreateNewApiRouter(db *storm.DB) *mux.Router { router := mux.NewRouter() // initialize repos nodeRepo := repositories.NewNodeRepo(db) + pingRepo := repositories.NewPingRepo(db) // initialize controllers - apiController := controllers.NewApiController(nodeRepo) + apiController := controllers.NewApiController(nodeRepo, pingRepo) // map controllers handlers to endpoints - router.HandleFunc("/api/v1/nodes", apiController.RegisterHandler).Methods("POST").Name("/api/v1/nodes") - + createRoute("/api/v1/nodes", "POST", apiController.RegisterHandler, router, false) + createRoute("/api/v1/nodes/pings", "POST", apiController.PingHandler, router, true) return router } + +func createRoute(route string, method string, handler http.HandlerFunc, router *mux.Router, authorized bool) { + var r *mux.Route + if authorized { + r = router.Handle(route, auth.AuthMiddleware(handler)) + } else { + r = router.Handle(route, handler) + } + r.Methods(method) + r.Name(route) +} diff --git a/mocks/models/NodeRepository.go b/mocks/models/NodeRepository.go index 970b2e72..1bad3956 100644 --- a/mocks/models/NodeRepository.go +++ b/mocks/models/NodeRepository.go @@ -11,11 +11,11 @@ type NodeRepository struct { } // FindByID provides a mock function with given fields: ID -func (_m *NodeRepository) FindByID(ID int) (*models.Node, error) { +func (_m *NodeRepository) FindByID(ID string) (*models.Node, error) { ret := _m.Called(ID) var r0 *models.Node - if rf, ok := ret.Get(0).(func(int) *models.Node); ok { + if rf, ok := ret.Get(0).(func(string) *models.Node); ok { r0 = rf(ID) } else { if ret.Get(0) != nil { @@ -24,7 +24,7 @@ func (_m *NodeRepository) FindByID(ID int) (*models.Node, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { + if rf, ok := ret.Get(1).(func(string) error); ok { r1 = rf(ID) } else { r1 = ret.Error(1) From 7ad5bc331a02ca59840e627087d1ea629426303b Mon Sep 17 00:00:00 2001 From: Mak Muftic Date: Fri, 18 Sep 2020 10:33:57 +0200 Subject: [PATCH 02/10] Add test for creating auth token --- internal/auth/auth_test.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index c417d282..2dd41abb 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -1,6 +1,7 @@ package auth import ( + "github.com/dgrijalva/jwt-go" "github.com/stretchr/testify/assert" "os" "testing" @@ -39,3 +40,15 @@ func TestSetAuthSecret(t *testing.T) { }) } } + +func TestCreateNewToken(t *testing.T) { + jwtToken, err := CreateNewToken("test-node-1") + assert.NoError(t, err, "Should successfully generate token") + token, err := jwt.ParseWithClaims(jwtToken, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) { + return []byte(authSecret), nil + }) + assert.NoError(t, err, "Should successfully parse token") + claims, ok := token.Claims.(*CustomClaims) + assert.True(t, ok, "Should contain custom claims") + assert.Equal(t, "test-node-1", claims.NodeId, "Claims should have nodeId") +} From 9648b9fcb3e35b1d9bb0244ccc515c51cda37bd1 Mon Sep 17 00:00:00 2001 From: Mak Muftic Date: Fri, 18 Sep 2020 11:22:21 +0200 Subject: [PATCH 03/10] Add ping tests and add timestamp to context --- internal/auth/middleware.go | 2 + internal/controllers/ping.go | 3 +- internal/controllers/ping_test.go | 58 ++++++++++++++++++++++ internal/controllers/register_test.go | 3 +- internal/time/time.go | 2 + mocks/models/PingRepository.go | 71 +++++++++++++++++++++++++++ 6 files changed, 137 insertions(+), 2 deletions(-) create mode 100644 internal/controllers/ping_test.go create mode 100644 internal/time/time.go create mode 100644 mocks/models/PingRepository.go diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index 4b1c588f..dd3496e6 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/dgrijalva/jwt-go" "net/http" + "time" ) func AuthMiddleware(next http.Handler) http.Handler { @@ -25,6 +26,7 @@ func AuthMiddleware(next http.Handler) http.Handler { if claims, ok := token.Claims.(*CustomClaims); ok && token.Valid { ctx := context.WithValue(r.Context(), "node-id", claims.NodeId) + ctx = context.WithValue(r.Context(), "timestamp", time.Now()) // Access context values in handlers like this // props, _ := r.Context().Value("props").(jwt.MapClaims) next.ServeHTTP(w, r.WithContext(ctx)) diff --git a/internal/controllers/ping.go b/internal/controllers/ping.go index 282e231f..51712e87 100644 --- a/internal/controllers/ping.go +++ b/internal/controllers/ping.go @@ -8,9 +8,10 @@ import ( func (c ApiController) PingHandler(w http.ResponseWriter, r *http.Request) { id := r.Context().Value("node-id").(string) + timestamp := r.Context().Value("timestamp").(time.Time) err := c.pingRepo.Save(&models.Ping{ NodeId: id, - Timestamp: time.Now(), + Timestamp: timestamp, }) if err != nil { // todo handle diff --git a/internal/controllers/ping_test.go b/internal/controllers/ping_test.go new file mode 100644 index 00000000..1a39f097 --- /dev/null +++ b/internal/controllers/ping_test.go @@ -0,0 +1,58 @@ +package controllers + +import ( + "bytes" + "context" + "fmt" + "github.com/NodeFactoryIo/vedran/internal/models" + mocks "github.com/NodeFactoryIo/vedran/mocks/models" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" +) + +func TestPingHandler(t *testing.T) { + // define test cases + tests := []struct { + name string + httpStatus int + }{ + { + name: "Valid ping test", + httpStatus: http.StatusOK, + }, + } + _ = os.Setenv("AUTH_SECRET", "test-auth-secret") + // execute tests + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + timestamp := time.Now() + nodeRepoMock := mocks.NodeRepository{} + pingRepoMock := mocks.PingRepository{} + pingRepoMock.On("Save", &models.Ping{ + NodeId: "1", + Timestamp: timestamp, + }).Return(nil) + + rr := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/node", bytes.NewReader(nil)) + ctx := req.Context() + ctx = context.WithValue(ctx, "node-id", "1") + ctx = context.WithValue(ctx, "timestamp", timestamp) + req = req.WithContext(ctx) + + apiController := NewApiController(&nodeRepoMock, &pingRepoMock) + handler := http.HandlerFunc(apiController.PingHandler) + // invoke test request + handler.ServeHTTP(rr, req) + + // asserts + assert.Equal(t, rr.Code, test.httpStatus, fmt.Sprintf("Response status code should be %d", test.httpStatus)) + assert.True(t, pingRepoMock.AssertNumberOfCalls(t, "Save", 1)) + }) + } + _ = os.Setenv("AUTH_SECRET", "") +} \ No newline at end of file diff --git a/internal/controllers/register_test.go b/internal/controllers/register_test.go index 5cf42edc..79bc8863 100644 --- a/internal/controllers/register_test.go +++ b/internal/controllers/register_test.go @@ -40,6 +40,7 @@ func TestRegisterHandler(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { nodeRepoMock := mocks.NodeRepository{} + pingRepoMock := mocks.PingRepository{} nodeRepoMock.On("Save", &models.Node{ ID: test.registerRequest.Id, ConfigHash: test.registerRequest.ConfigHash, @@ -54,7 +55,7 @@ func TestRegisterHandler(t *testing.T) { t.Fatal(err) } rr := httptest.NewRecorder() - apiController := NewApiController(&nodeRepoMock) + apiController := NewApiController(&nodeRepoMock, &pingRepoMock) handler := http.HandlerFunc(apiController.RegisterHandler) // invoke test request handler.ServeHTTP(rr, req) diff --git a/internal/time/time.go b/internal/time/time.go new file mode 100644 index 00000000..677ffca4 --- /dev/null +++ b/internal/time/time.go @@ -0,0 +1,2 @@ +package time + diff --git a/mocks/models/PingRepository.go b/mocks/models/PingRepository.go new file mode 100644 index 00000000..3a4ffca5 --- /dev/null +++ b/mocks/models/PingRepository.go @@ -0,0 +1,71 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import models "github.com/NodeFactoryIo/vedran/internal/models" + +// PingRepository is an autogenerated mock type for the PingRepository type +type PingRepository struct { + mock.Mock +} + +// FindByNodeID provides a mock function with given fields: nodeId +func (_m *PingRepository) FindByNodeID(nodeId string) (*models.Ping, error) { + ret := _m.Called(nodeId) + + var r0 *models.Ping + if rf, ok := ret.Get(0).(func(string) *models.Ping); ok { + r0 = rf(nodeId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.Ping) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(nodeId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetAll provides a mock function with given fields: +func (_m *PingRepository) GetAll() (*[]models.Ping, error) { + ret := _m.Called() + + var r0 *[]models.Ping + if rf, ok := ret.Get(0).(func() *[]models.Ping); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*[]models.Ping) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Save provides a mock function with given fields: ping +func (_m *PingRepository) Save(ping *models.Ping) error { + ret := _m.Called(ping) + + var r0 error + if rf, ok := ret.Get(0).(func(*models.Ping) error); ok { + r0 = rf(ping) + } else { + r0 = ret.Error(0) + } + + return r0 +} From 1ccb4cdf16d3cf551d89af12592bfcf0375bc4aa Mon Sep 17 00:00:00 2001 From: Mak Muftic Date: Fri, 18 Sep 2020 11:28:40 +0200 Subject: [PATCH 04/10] Code cleanup in test files --- internal/controllers/ping_test.go | 63 +++++++++++---------------- internal/controllers/register_test.go | 2 +- 2 files changed, 27 insertions(+), 38 deletions(-) diff --git a/internal/controllers/ping_test.go b/internal/controllers/ping_test.go index 1a39f097..a4e96c9f 100644 --- a/internal/controllers/ping_test.go +++ b/internal/controllers/ping_test.go @@ -14,45 +14,34 @@ import ( "time" ) -func TestPingHandler(t *testing.T) { - // define test cases - tests := []struct { - name string - httpStatus int - }{ - { - name: "Valid ping test", - httpStatus: http.StatusOK, - }, - } +func TestApiController_PingHandler(t *testing.T) { _ = os.Setenv("AUTH_SECRET", "test-auth-secret") - // execute tests - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - timestamp := time.Now() - nodeRepoMock := mocks.NodeRepository{} - pingRepoMock := mocks.PingRepository{} - pingRepoMock.On("Save", &models.Ping{ - NodeId: "1", - Timestamp: timestamp, - }).Return(nil) + timestamp := time.Now() - rr := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/api/v1/node", bytes.NewReader(nil)) - ctx := req.Context() - ctx = context.WithValue(ctx, "node-id", "1") - ctx = context.WithValue(ctx, "timestamp", timestamp) - req = req.WithContext(ctx) + // create mock controller + nodeRepoMock := mocks.NodeRepository{} + pingRepoMock := mocks.PingRepository{} + pingRepoMock.On("Save", &models.Ping{ + NodeId: "1", + Timestamp: timestamp, + }).Return(nil) + apiController := NewApiController(&nodeRepoMock, &pingRepoMock) + handler := http.HandlerFunc(apiController.PingHandler) - apiController := NewApiController(&nodeRepoMock, &pingRepoMock) - handler := http.HandlerFunc(apiController.PingHandler) - // invoke test request - handler.ServeHTTP(rr, req) + // create request and populate context + req, _ := http.NewRequest("POST", "/api/v1/node", bytes.NewReader(nil)) + ctx := req.Context() + ctx = context.WithValue(ctx, "node-id", "1") + ctx = context.WithValue(ctx, "timestamp", timestamp) + req = req.WithContext(ctx) + rr := httptest.NewRecorder() + + // invoke test request + handler.ServeHTTP(rr, req) + + // asserts + assert.Equal(t, rr.Code, http.StatusOK, fmt.Sprintf("Response status code should be %d", http.StatusOK)) + assert.True(t, pingRepoMock.AssertNumberOfCalls(t, "Save", 1)) - // asserts - assert.Equal(t, rr.Code, test.httpStatus, fmt.Sprintf("Response status code should be %d", test.httpStatus)) - assert.True(t, pingRepoMock.AssertNumberOfCalls(t, "Save", 1)) - }) - } _ = os.Setenv("AUTH_SECRET", "") -} \ No newline at end of file +} diff --git a/internal/controllers/register_test.go b/internal/controllers/register_test.go index 79bc8863..af722061 100644 --- a/internal/controllers/register_test.go +++ b/internal/controllers/register_test.go @@ -13,7 +13,7 @@ import ( "testing" ) -func TestRegisterHandler(t *testing.T) { +func TestApiController_RegisterHandler(t *testing.T) { // define test cases tests := []struct { name string From ff4c8407d2f2282aeb34d6c6cbd8b8efc8c362ec Mon Sep 17 00:00:00 2001 From: Mak Muftic Date: Fri, 18 Sep 2020 11:31:32 +0200 Subject: [PATCH 05/10] Handle database error in ping handler --- internal/controllers/ping.go | 6 +++++- internal/controllers/ping_test.go | 2 +- internal/controllers/register_test.go | 8 ++++++-- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/internal/controllers/ping.go b/internal/controllers/ping.go index 51712e87..bdda5976 100644 --- a/internal/controllers/ping.go +++ b/internal/controllers/ping.go @@ -2,6 +2,7 @@ package controllers import ( "github.com/NodeFactoryIo/vedran/internal/models" + "log" "net/http" "time" ) @@ -14,6 +15,9 @@ func (c ApiController) PingHandler(w http.ResponseWriter, r *http.Request) { Timestamp: timestamp, }) if err != nil { - // todo handle + // error on saving in database + log.Println(err.Error()) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return } } \ No newline at end of file diff --git a/internal/controllers/ping_test.go b/internal/controllers/ping_test.go index a4e96c9f..a8460782 100644 --- a/internal/controllers/ping_test.go +++ b/internal/controllers/ping_test.go @@ -28,7 +28,7 @@ func TestApiController_PingHandler(t *testing.T) { apiController := NewApiController(&nodeRepoMock, &pingRepoMock) handler := http.HandlerFunc(apiController.PingHandler) - // create request and populate context + // create test request and populate context req, _ := http.NewRequest("POST", "/api/v1/node", bytes.NewReader(nil)) ctx := req.Context() ctx = context.WithValue(ctx, "node-id", "1") diff --git a/internal/controllers/register_test.go b/internal/controllers/register_test.go index af722061..1788eda5 100644 --- a/internal/controllers/register_test.go +++ b/internal/controllers/register_test.go @@ -39,6 +39,7 @@ func TestApiController_RegisterHandler(t *testing.T) { // execute tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { + // create mock controller nodeRepoMock := mocks.NodeRepository{} pingRepoMock := mocks.PingRepository{} nodeRepoMock.On("Save", &models.Node{ @@ -48,6 +49,9 @@ func TestApiController_RegisterHandler(t *testing.T) { PayoutAddress: test.registerRequest.PayoutAddress, Token: test.registerResponse.Token, }).Return(nil) + apiController := NewApiController(&nodeRepoMock, &pingRepoMock) + handler := http.HandlerFunc(apiController.RegisterHandler) + // create test request rb, _ := json.Marshal(test.registerRequest) req, err := http.NewRequest("POST", "/api/v1/node", bytes.NewReader(rb)) @@ -55,12 +59,12 @@ func TestApiController_RegisterHandler(t *testing.T) { t.Fatal(err) } rr := httptest.NewRecorder() - apiController := NewApiController(&nodeRepoMock, &pingRepoMock) - handler := http.HandlerFunc(apiController.RegisterHandler) + // invoke test request handler.ServeHTTP(rr, req) var response RegisterResponse _ = json.Unmarshal(rr.Body.Bytes(), &response) + // asserts assert.Equal(t, rr.Code, test.httpStatus, fmt.Sprintf("Response status code should be %d", test.httpStatus)) assert.Equal(t, response, test.registerResponse, fmt.Sprintf("Response should be %v", test.registerResponse)) From 2653f1e64d876a9a2dc66b9cea583b5bff00e4fe Mon Sep 17 00:00:00 2001 From: Mak Muftic Date: Fri, 18 Sep 2020 13:52:04 +0200 Subject: [PATCH 06/10] Refactor request context and add middleware tests --- internal/auth/auth.go | 4 ++-- internal/auth/middleware.go | 37 ++++++++++++++++++------------ internal/auth/middleware_test.go | 38 +++++++++++++++++++++++++++++++ internal/controllers/ping.go | 9 ++++---- internal/controllers/ping_test.go | 9 +++++--- 5 files changed, 72 insertions(+), 25 deletions(-) create mode 100644 internal/auth/middleware_test.go diff --git a/internal/auth/auth.go b/internal/auth/auth.go index a7feed74..35e024c7 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -26,8 +26,8 @@ func SetAuthSecret(secret string) error { } type CustomClaims struct { - Authorized bool `json:"authorized"` - NodeId string `json:"node_id"` + Authorized bool `json:"authorized"` + NodeId string `json:"node_id"` jwt.StandardClaims } diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index dd3496e6..870bc809 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -4,10 +4,18 @@ import ( "context" "fmt" "github.com/dgrijalva/jwt-go" + "log" "net/http" "time" ) +const RequestContextKey = "request" + +type RequestContext struct { + NodeId string + Timestamp time.Time +} + func AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { jwtToken := r.Header.Get("X-Auth-Token") @@ -19,21 +27,20 @@ func AuthMiddleware(next http.Handler) http.Handler { return []byte(authSecret), nil }) - if token == nil { - // todo - return + if err == nil { + if claims, ok := token.Claims.(*CustomClaims); ok && token.Valid { + c := &RequestContext{ + NodeId: claims.NodeId, + Timestamp: time.Now(), + } + ctx := context.WithValue(r.Context(), RequestContextKey, c) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } } - if claims, ok := token.Claims.(*CustomClaims); ok && token.Valid { - ctx := context.WithValue(r.Context(), "node-id", claims.NodeId) - ctx = context.WithValue(r.Context(), "timestamp", time.Now()) - // Access context values in handlers like this - // props, _ := r.Context().Value("props").(jwt.MapClaims) - next.ServeHTTP(w, r.WithContext(ctx)) - } else { - fmt.Println(err) - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte("Unauthorized")) - } + log.Println(err) + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte("Unauthorized")) }) -} \ No newline at end of file +} diff --git a/internal/auth/middleware_test.go b/internal/auth/middleware_test.go new file mode 100644 index 00000000..0f7810a7 --- /dev/null +++ b/internal/auth/middleware_test.go @@ -0,0 +1,38 @@ +package auth + +import ( + "bytes" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" +) + +func TestAuthMiddleware_AuthorizedRequest(t *testing.T) { + token, _ := CreateNewToken("test-node-id") + req, _ := http.NewRequest("POST", "/", bytes.NewReader(nil)) + req.Header.Add("X-Auth-Token", token) + rr := httptest.NewRecorder() + + mockHandler := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + r := request.Context().Value("request").(*RequestContext) + assert.Equal(t, r.NodeId, "test-node-id") + assert.NotNil(t, r.Timestamp) + }) + + handler := AuthMiddleware(mockHandler) + handler.ServeHTTP(rr, req) +} + +func TestAuthMiddleware_UnauthorizedRequest(t *testing.T) { + token := "invalidtokenstring" + req, _ := http.NewRequest("POST", "/", bytes.NewReader(nil)) + req.Header.Add("X-Auth-Token", token) + rr := httptest.NewRecorder() + + mockHandler := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {}) + + handler := AuthMiddleware(mockHandler) + handler.ServeHTTP(rr, req) + assert.Equal(t, rr.Code, http.StatusUnauthorized) +} \ No newline at end of file diff --git a/internal/controllers/ping.go b/internal/controllers/ping.go index bdda5976..d2b2f440 100644 --- a/internal/controllers/ping.go +++ b/internal/controllers/ping.go @@ -1,18 +1,17 @@ package controllers import ( + "github.com/NodeFactoryIo/vedran/internal/auth" "github.com/NodeFactoryIo/vedran/internal/models" "log" "net/http" - "time" ) func (c ApiController) PingHandler(w http.ResponseWriter, r *http.Request) { - id := r.Context().Value("node-id").(string) - timestamp := r.Context().Value("timestamp").(time.Time) + request := r.Context().Value(auth.RequestContextKey).(*auth.RequestContext) err := c.pingRepo.Save(&models.Ping{ - NodeId: id, - Timestamp: timestamp, + NodeId: request.NodeId, + Timestamp: request.Timestamp, }) if err != nil { // error on saving in database diff --git a/internal/controllers/ping_test.go b/internal/controllers/ping_test.go index a8460782..de0c598c 100644 --- a/internal/controllers/ping_test.go +++ b/internal/controllers/ping_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "github.com/NodeFactoryIo/vedran/internal/auth" "github.com/NodeFactoryIo/vedran/internal/models" mocks "github.com/NodeFactoryIo/vedran/mocks/models" "github.com/stretchr/testify/assert" @@ -30,9 +31,11 @@ func TestApiController_PingHandler(t *testing.T) { // create test request and populate context req, _ := http.NewRequest("POST", "/api/v1/node", bytes.NewReader(nil)) - ctx := req.Context() - ctx = context.WithValue(ctx, "node-id", "1") - ctx = context.WithValue(ctx, "timestamp", timestamp) + c := &auth.RequestContext{ + NodeId: "1", + Timestamp: timestamp, + } + ctx := context.WithValue(req.Context(), "request", c) req = req.WithContext(ctx) rr := httptest.NewRecorder() From c5d05c573084dac341a2efca195ec8d36d1b2c6c Mon Sep 17 00:00:00 2001 From: Mak Muftic Date: Fri, 18 Sep 2020 14:00:37 +0200 Subject: [PATCH 07/10] Add context key type --- internal/auth/middleware.go | 4 +++- internal/controllers/ping_test.go | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index 870bc809..66f65f15 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -9,7 +9,9 @@ import ( "time" ) -const RequestContextKey = "request" +type ContextKey string + +const RequestContextKey = ContextKey("request") type RequestContext struct { NodeId string diff --git a/internal/controllers/ping_test.go b/internal/controllers/ping_test.go index de0c598c..04a05163 100644 --- a/internal/controllers/ping_test.go +++ b/internal/controllers/ping_test.go @@ -35,7 +35,7 @@ func TestApiController_PingHandler(t *testing.T) { NodeId: "1", Timestamp: timestamp, } - ctx := context.WithValue(req.Context(), "request", c) + ctx := context.WithValue(req.Context(), auth.RequestContextKey, c) req = req.WithContext(ctx) rr := httptest.NewRecorder() From 9a18c62494ce0d82f94b62efe41f11fa7eb5205d Mon Sep 17 00:00:00 2001 From: Mak Muftic Date: Fri, 18 Sep 2020 15:23:13 +0200 Subject: [PATCH 08/10] Fix failing test --- internal/auth/middleware_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/auth/middleware_test.go b/internal/auth/middleware_test.go index 0f7810a7..0e0d44f5 100644 --- a/internal/auth/middleware_test.go +++ b/internal/auth/middleware_test.go @@ -15,7 +15,7 @@ func TestAuthMiddleware_AuthorizedRequest(t *testing.T) { rr := httptest.NewRecorder() mockHandler := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - r := request.Context().Value("request").(*RequestContext) + r := request.Context().Value(RequestContextKey).(*RequestContext) assert.Equal(t, r.NodeId, "test-node-id") assert.NotNil(t, r.Timestamp) }) From 2133ad5707b2791013052471c82b30c2a2c119b7 Mon Sep 17 00:00:00 2001 From: Mak Muftic Date: Fri, 18 Sep 2020 15:47:48 +0200 Subject: [PATCH 09/10] Refactor test to table test --- internal/auth/auth_test.go | 28 +++++++++----- internal/auth/middleware.go | 2 +- internal/auth/middleware_test.go | 61 ++++++++++++++++++------------- internal/controllers/ping_test.go | 60 ++++++++++++++++-------------- 4 files changed, 87 insertions(+), 64 deletions(-) diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 2dd41abb..58ac3ed8 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -42,13 +42,23 @@ func TestSetAuthSecret(t *testing.T) { } func TestCreateNewToken(t *testing.T) { - jwtToken, err := CreateNewToken("test-node-1") - assert.NoError(t, err, "Should successfully generate token") - token, err := jwt.ParseWithClaims(jwtToken, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) { - return []byte(authSecret), nil - }) - assert.NoError(t, err, "Should successfully parse token") - claims, ok := token.Claims.(*CustomClaims) - assert.True(t, ok, "Should contain custom claims") - assert.Equal(t, "test-node-1", claims.NodeId, "Claims should have nodeId") + tests := []struct { + name string + }{ + {name: "Valid token with claims generated"}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + jwtToken, err := CreateNewToken("test-node-1") + assert.NoError(t, err, "Should successfully generate token") + token, err := jwt.ParseWithClaims(jwtToken, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) { + return []byte(authSecret), nil + }) + assert.NoError(t, err, "Should successfully parse token") + claims, ok := token.Claims.(*CustomClaims) + assert.True(t, ok, "Should contain custom claims") + assert.Equal(t, "test-node-1", claims.NodeId, "Claims should have nodeId") + }) + } } + diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index 66f65f15..b7616970 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -41,7 +41,7 @@ func AuthMiddleware(next http.Handler) http.Handler { } } - log.Println(err) + log.Println("Unauthorized request:", err) w.WriteHeader(http.StatusUnauthorized) _, _ = w.Write([]byte("Unauthorized")) }) diff --git a/internal/auth/middleware_test.go b/internal/auth/middleware_test.go index 0e0d44f5..ca153caf 100644 --- a/internal/auth/middleware_test.go +++ b/internal/auth/middleware_test.go @@ -8,31 +8,40 @@ import ( "testing" ) -func TestAuthMiddleware_AuthorizedRequest(t *testing.T) { - token, _ := CreateNewToken("test-node-id") - req, _ := http.NewRequest("POST", "/", bytes.NewReader(nil)) - req.Header.Add("X-Auth-Token", token) - rr := httptest.NewRecorder() +func TestAuthMiddleware(t *testing.T) { + validToken, _ := CreateNewToken("test-node-id") + tests := []struct { + name string + token string + status int + mockHandle http.HandlerFunc + }{ + { + name: "Authorized request", + token: validToken, + status: http.StatusOK, + mockHandle: func(writer http.ResponseWriter, request *http.Request) { + r := request.Context().Value(RequestContextKey).(*RequestContext) + assert.Equal(t, r.NodeId, "test-node-id") + assert.NotNil(t, r.Timestamp) + }, + }, + { + name: "Unauthorized request", + token: "invalidtokenstring", + status: http.StatusUnauthorized, + mockHandle: func(writer http.ResponseWriter, request *http.Request) {}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + req, _ := http.NewRequest("POST", "/", bytes.NewReader(nil)) + req.Header.Add("X-Auth-Token", test.token) + rr := httptest.NewRecorder() - mockHandler := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - r := request.Context().Value(RequestContextKey).(*RequestContext) - assert.Equal(t, r.NodeId, "test-node-id") - assert.NotNil(t, r.Timestamp) - }) - - handler := AuthMiddleware(mockHandler) - handler.ServeHTTP(rr, req) + handler := AuthMiddleware(test.mockHandle) + handler.ServeHTTP(rr, req) + assert.Equal(t, rr.Code, test.status) + }) + } } - -func TestAuthMiddleware_UnauthorizedRequest(t *testing.T) { - token := "invalidtokenstring" - req, _ := http.NewRequest("POST", "/", bytes.NewReader(nil)) - req.Header.Add("X-Auth-Token", token) - rr := httptest.NewRecorder() - - mockHandler := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {}) - - handler := AuthMiddleware(mockHandler) - handler.ServeHTTP(rr, req) - assert.Equal(t, rr.Code, http.StatusUnauthorized) -} \ No newline at end of file diff --git a/internal/controllers/ping_test.go b/internal/controllers/ping_test.go index 04a05163..7b3273e8 100644 --- a/internal/controllers/ping_test.go +++ b/internal/controllers/ping_test.go @@ -10,41 +10,45 @@ import ( "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" - "os" "testing" "time" ) func TestApiController_PingHandler(t *testing.T) { - _ = os.Setenv("AUTH_SECRET", "test-auth-secret") - timestamp := time.Now() - - // create mock controller - nodeRepoMock := mocks.NodeRepository{} - pingRepoMock := mocks.PingRepository{} - pingRepoMock.On("Save", &models.Ping{ - NodeId: "1", - Timestamp: timestamp, - }).Return(nil) - apiController := NewApiController(&nodeRepoMock, &pingRepoMock) - handler := http.HandlerFunc(apiController.PingHandler) - - // create test request and populate context - req, _ := http.NewRequest("POST", "/api/v1/node", bytes.NewReader(nil)) - c := &auth.RequestContext{ - NodeId: "1", - Timestamp: timestamp, + tests := []struct { + name string + }{ + {name: "Valid ping request"}, } - ctx := context.WithValue(req.Context(), auth.RequestContextKey, c) - req = req.WithContext(ctx) - rr := httptest.NewRecorder() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + timestamp := time.Now() + // create mock controller + nodeRepoMock := mocks.NodeRepository{} + pingRepoMock := mocks.PingRepository{} + pingRepoMock.On("Save", &models.Ping{ + NodeId: "1", + Timestamp: timestamp, + }).Return(nil) + apiController := NewApiController(&nodeRepoMock, &pingRepoMock) + handler := http.HandlerFunc(apiController.PingHandler) - // invoke test request - handler.ServeHTTP(rr, req) + // create test request and populate context + req, _ := http.NewRequest("POST", "/api/v1/node", bytes.NewReader(nil)) + c := &auth.RequestContext{ + NodeId: "1", + Timestamp: timestamp, + } + ctx := context.WithValue(req.Context(), auth.RequestContextKey, c) + req = req.WithContext(ctx) + rr := httptest.NewRecorder() - // asserts - assert.Equal(t, rr.Code, http.StatusOK, fmt.Sprintf("Response status code should be %d", http.StatusOK)) - assert.True(t, pingRepoMock.AssertNumberOfCalls(t, "Save", 1)) + // invoke test request + handler.ServeHTTP(rr, req) - _ = os.Setenv("AUTH_SECRET", "") + // asserts + assert.Equal(t, rr.Code, http.StatusOK, fmt.Sprintf("Response status code should be %d", http.StatusOK)) + assert.True(t, pingRepoMock.AssertNumberOfCalls(t, "Save", 1)) + }) + } } From e40e03d25044cf1c3c7001217211548e248460d2 Mon Sep 17 00:00:00 2001 From: Mak Muftic Date: Fri, 18 Sep 2020 16:53:12 +0200 Subject: [PATCH 10/10] Fix X-Auth-Header literal --- internal/auth/middleware.go | 2 +- internal/auth/middleware_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index b7616970..7f684a89 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -20,7 +20,7 @@ type RequestContext struct { func AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - jwtToken := r.Header.Get("X-Auth-Token") + jwtToken := r.Header.Get("X-Auth-Header") token, err := jwt.ParseWithClaims(jwtToken, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { diff --git a/internal/auth/middleware_test.go b/internal/auth/middleware_test.go index ca153caf..77e3f6a0 100644 --- a/internal/auth/middleware_test.go +++ b/internal/auth/middleware_test.go @@ -36,7 +36,7 @@ func TestAuthMiddleware(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { req, _ := http.NewRequest("POST", "/", bytes.NewReader(nil)) - req.Header.Add("X-Auth-Token", test.token) + req.Header.Add("X-Auth-Header", test.token) rr := httptest.NewRecorder() handler := AuthMiddleware(test.mockHandle)