diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 166cf904..35e024c7 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -25,11 +25,17 @@ func SetAuthSecret(secret string) error { } } +type CustomClaims struct { + Authorized bool `json:"authorized"` + NodeId string `json:"node_id"` + jwt.StandardClaims +} + func CreateNewToken(nodeId string) (string, error) { - // set claims - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "authorized": true, - "node_id": nodeId, - }) + claims := CustomClaims{ + Authorized: true, + NodeId: nodeId, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) return token.SignedString([]byte(authSecret)) } diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index c417d282..58ac3ed8 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -1,6 +1,7 @@ package auth import ( + "github.com/dgrijalva/jwt-go" "github.com/stretchr/testify/assert" "os" "testing" @@ -39,3 +40,25 @@ func TestSetAuthSecret(t *testing.T) { }) } } + +func TestCreateNewToken(t *testing.T) { + tests := []struct { + name string + }{ + {name: "Valid token with claims generated"}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + jwtToken, err := CreateNewToken("test-node-1") + assert.NoError(t, err, "Should successfully generate token") + token, err := jwt.ParseWithClaims(jwtToken, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) { + return []byte(authSecret), nil + }) + assert.NoError(t, err, "Should successfully parse token") + claims, ok := token.Claims.(*CustomClaims) + assert.True(t, ok, "Should contain custom claims") + assert.Equal(t, "test-node-1", claims.NodeId, "Claims should have nodeId") + }) + } +} + diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go new file mode 100644 index 00000000..7f684a89 --- /dev/null +++ b/internal/auth/middleware.go @@ -0,0 +1,48 @@ +package auth + +import ( + "context" + "fmt" + "github.com/dgrijalva/jwt-go" + "log" + "net/http" + "time" +) + +type ContextKey string + +const RequestContextKey = ContextKey("request") + +type RequestContext struct { + NodeId string + Timestamp time.Time +} + +func AuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + jwtToken := r.Header.Get("X-Auth-Header") + + token, err := jwt.ParseWithClaims(jwtToken, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(authSecret), nil + }) + + if err == nil { + if claims, ok := token.Claims.(*CustomClaims); ok && token.Valid { + c := &RequestContext{ + NodeId: claims.NodeId, + Timestamp: time.Now(), + } + ctx := context.WithValue(r.Context(), RequestContextKey, c) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + } + + log.Println("Unauthorized request:", err) + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte("Unauthorized")) + }) +} diff --git a/internal/auth/middleware_test.go b/internal/auth/middleware_test.go new file mode 100644 index 00000000..77e3f6a0 --- /dev/null +++ b/internal/auth/middleware_test.go @@ -0,0 +1,47 @@ +package auth + +import ( + "bytes" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" +) + +func TestAuthMiddleware(t *testing.T) { + validToken, _ := CreateNewToken("test-node-id") + tests := []struct { + name string + token string + status int + mockHandle http.HandlerFunc + }{ + { + name: "Authorized request", + token: validToken, + status: http.StatusOK, + mockHandle: func(writer http.ResponseWriter, request *http.Request) { + r := request.Context().Value(RequestContextKey).(*RequestContext) + assert.Equal(t, r.NodeId, "test-node-id") + assert.NotNil(t, r.Timestamp) + }, + }, + { + name: "Unauthorized request", + token: "invalidtokenstring", + status: http.StatusUnauthorized, + mockHandle: func(writer http.ResponseWriter, request *http.Request) {}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + req, _ := http.NewRequest("POST", "/", bytes.NewReader(nil)) + req.Header.Add("X-Auth-Header", test.token) + rr := httptest.NewRecorder() + + handler := AuthMiddleware(test.mockHandle) + handler.ServeHTTP(rr, req) + assert.Equal(t, rr.Code, test.status) + }) + } +} diff --git a/internal/controllers/api.go b/internal/controllers/api.go index cea0746c..61d6aff1 100644 --- a/internal/controllers/api.go +++ b/internal/controllers/api.go @@ -6,10 +6,12 @@ import ( type ApiController struct { nodeRepo models.NodeRepository + pingRepo models.PingRepository } -func NewApiController(nodeRepo models.NodeRepository) *ApiController { +func NewApiController(nodeRepo models.NodeRepository, pingRepo models.PingRepository) *ApiController { return &ApiController{ nodeRepo: nodeRepo, + pingRepo: pingRepo, } } diff --git a/internal/controllers/ping.go b/internal/controllers/ping.go new file mode 100644 index 00000000..d2b2f440 --- /dev/null +++ b/internal/controllers/ping.go @@ -0,0 +1,22 @@ +package controllers + +import ( + "github.com/NodeFactoryIo/vedran/internal/auth" + "github.com/NodeFactoryIo/vedran/internal/models" + "log" + "net/http" +) + +func (c ApiController) PingHandler(w http.ResponseWriter, r *http.Request) { + request := r.Context().Value(auth.RequestContextKey).(*auth.RequestContext) + err := c.pingRepo.Save(&models.Ping{ + NodeId: request.NodeId, + Timestamp: request.Timestamp, + }) + if err != nil { + // error on saving in database + log.Println(err.Error()) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } +} \ No newline at end of file diff --git a/internal/controllers/ping_test.go b/internal/controllers/ping_test.go new file mode 100644 index 00000000..7b3273e8 --- /dev/null +++ b/internal/controllers/ping_test.go @@ -0,0 +1,54 @@ +package controllers + +import ( + "bytes" + "context" + "fmt" + "github.com/NodeFactoryIo/vedran/internal/auth" + "github.com/NodeFactoryIo/vedran/internal/models" + mocks "github.com/NodeFactoryIo/vedran/mocks/models" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestApiController_PingHandler(t *testing.T) { + tests := []struct { + name string + }{ + {name: "Valid ping request"}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + timestamp := time.Now() + // create mock controller + nodeRepoMock := mocks.NodeRepository{} + pingRepoMock := mocks.PingRepository{} + pingRepoMock.On("Save", &models.Ping{ + NodeId: "1", + Timestamp: timestamp, + }).Return(nil) + apiController := NewApiController(&nodeRepoMock, &pingRepoMock) + handler := http.HandlerFunc(apiController.PingHandler) + + // create test request and populate context + req, _ := http.NewRequest("POST", "/api/v1/node", bytes.NewReader(nil)) + c := &auth.RequestContext{ + NodeId: "1", + Timestamp: timestamp, + } + ctx := context.WithValue(req.Context(), auth.RequestContextKey, c) + req = req.WithContext(ctx) + rr := httptest.NewRecorder() + + // invoke test request + handler.ServeHTTP(rr, req) + + // asserts + assert.Equal(t, rr.Code, http.StatusOK, fmt.Sprintf("Response status code should be %d", http.StatusOK)) + assert.True(t, pingRepoMock.AssertNumberOfCalls(t, "Save", 1)) + }) + } +} diff --git a/internal/controllers/register_test.go b/internal/controllers/register_test.go index 5cf42edc..1788eda5 100644 --- a/internal/controllers/register_test.go +++ b/internal/controllers/register_test.go @@ -13,7 +13,7 @@ import ( "testing" ) -func TestRegisterHandler(t *testing.T) { +func TestApiController_RegisterHandler(t *testing.T) { // define test cases tests := []struct { name string @@ -39,7 +39,9 @@ func TestRegisterHandler(t *testing.T) { // execute tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { + // create mock controller nodeRepoMock := mocks.NodeRepository{} + pingRepoMock := mocks.PingRepository{} nodeRepoMock.On("Save", &models.Node{ ID: test.registerRequest.Id, ConfigHash: test.registerRequest.ConfigHash, @@ -47,6 +49,9 @@ func TestRegisterHandler(t *testing.T) { PayoutAddress: test.registerRequest.PayoutAddress, Token: test.registerResponse.Token, }).Return(nil) + apiController := NewApiController(&nodeRepoMock, &pingRepoMock) + handler := http.HandlerFunc(apiController.RegisterHandler) + // create test request rb, _ := json.Marshal(test.registerRequest) req, err := http.NewRequest("POST", "/api/v1/node", bytes.NewReader(rb)) @@ -54,12 +59,12 @@ func TestRegisterHandler(t *testing.T) { t.Fatal(err) } rr := httptest.NewRecorder() - apiController := NewApiController(&nodeRepoMock) - handler := http.HandlerFunc(apiController.RegisterHandler) + // invoke test request handler.ServeHTTP(rr, req) var response RegisterResponse _ = json.Unmarshal(rr.Body.Bytes(), &response) + // asserts assert.Equal(t, rr.Code, test.httpStatus, fmt.Sprintf("Response status code should be %d", test.httpStatus)) assert.Equal(t, response, test.registerResponse, fmt.Sprintf("Response should be %v", test.registerResponse)) diff --git a/internal/models/node.go b/internal/models/node.go index 9c799276..b9505e9d 100644 --- a/internal/models/node.go +++ b/internal/models/node.go @@ -9,7 +9,7 @@ type Node struct { } type NodeRepository interface { - FindByID(ID int) (*Node, error) + FindByID(ID string) (*Node, error) Save(node *Node) error GetAll() (*[]Node, error) } diff --git a/internal/models/ping.go b/internal/models/ping.go new file mode 100644 index 00000000..3dc0b6af --- /dev/null +++ b/internal/models/ping.go @@ -0,0 +1,14 @@ +package models + +import "time" + +type Ping struct { + NodeId string `storm:"id"` + Timestamp time.Time +} + +type PingRepository interface { + FindByNodeID(nodeId string) (*Ping, error) + Save(ping *Ping) error + GetAll() (*[]Ping, error) +} diff --git a/internal/repositories/node.go b/internal/repositories/node.go index c51e8645..41c4d319 100644 --- a/internal/repositories/node.go +++ b/internal/repositories/node.go @@ -15,7 +15,7 @@ func NewNodeRepo(db *storm.DB) *NodeRepo { } } -func (r *NodeRepo) FindByID(ID int) (*models.Node, error) { +func (r *NodeRepo) FindByID(ID string) (*models.Node, error) { var node *models.Node err := r.db.One("ID", ID, node) return node, err diff --git a/internal/repositories/ping.go b/internal/repositories/ping.go new file mode 100644 index 00000000..7281af1c --- /dev/null +++ b/internal/repositories/ping.go @@ -0,0 +1,32 @@ +package repositories + +import ( + "github.com/NodeFactoryIo/vedran/internal/models" + "github.com/asdine/storm/v3" +) + +type PingRepo struct { + db *storm.DB +} + +func NewPingRepo(db *storm.DB) *PingRepo { + return &PingRepo{ + db: db, + } +} + +func (r *PingRepo) FindByNodeID(nodeId string) (*models.Ping, error) { + var ping *models.Ping + err := r.db.One("NodeId", nodeId, ping) + return ping, err +} + +func (r *PingRepo) Save(ping *models.Ping) error { + return r.db.Save(ping) +} + +func (r PingRepo) GetAll() (*[]models.Ping, error) { + var pings []models.Ping + err := r.db.All(&pings) + return &pings, err +} \ No newline at end of file diff --git a/internal/router/router.go b/internal/router/router.go index 8823f460..406f7359 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -1,20 +1,34 @@ package router import ( + "github.com/NodeFactoryIo/vedran/internal/auth" "github.com/NodeFactoryIo/vedran/internal/controllers" "github.com/NodeFactoryIo/vedran/internal/repositories" "github.com/asdine/storm/v3" "github.com/gorilla/mux" + "net/http" ) func CreateNewApiRouter(db *storm.DB) *mux.Router { router := mux.NewRouter() // initialize repos nodeRepo := repositories.NewNodeRepo(db) + pingRepo := repositories.NewPingRepo(db) // initialize controllers - apiController := controllers.NewApiController(nodeRepo) + apiController := controllers.NewApiController(nodeRepo, pingRepo) // map controllers handlers to endpoints - router.HandleFunc("/api/v1/nodes", apiController.RegisterHandler).Methods("POST").Name("/api/v1/nodes") - + createRoute("/api/v1/nodes", "POST", apiController.RegisterHandler, router, false) + createRoute("/api/v1/nodes/pings", "POST", apiController.PingHandler, router, true) return router } + +func createRoute(route string, method string, handler http.HandlerFunc, router *mux.Router, authorized bool) { + var r *mux.Route + if authorized { + r = router.Handle(route, auth.AuthMiddleware(handler)) + } else { + r = router.Handle(route, handler) + } + r.Methods(method) + r.Name(route) +} diff --git a/internal/time/time.go b/internal/time/time.go new file mode 100644 index 00000000..677ffca4 --- /dev/null +++ b/internal/time/time.go @@ -0,0 +1,2 @@ +package time + diff --git a/mocks/models/NodeRepository.go b/mocks/models/NodeRepository.go index 970b2e72..1bad3956 100644 --- a/mocks/models/NodeRepository.go +++ b/mocks/models/NodeRepository.go @@ -11,11 +11,11 @@ type NodeRepository struct { } // FindByID provides a mock function with given fields: ID -func (_m *NodeRepository) FindByID(ID int) (*models.Node, error) { +func (_m *NodeRepository) FindByID(ID string) (*models.Node, error) { ret := _m.Called(ID) var r0 *models.Node - if rf, ok := ret.Get(0).(func(int) *models.Node); ok { + if rf, ok := ret.Get(0).(func(string) *models.Node); ok { r0 = rf(ID) } else { if ret.Get(0) != nil { @@ -24,7 +24,7 @@ func (_m *NodeRepository) FindByID(ID int) (*models.Node, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { + if rf, ok := ret.Get(1).(func(string) error); ok { r1 = rf(ID) } else { r1 = ret.Error(1) diff --git a/mocks/models/PingRepository.go b/mocks/models/PingRepository.go new file mode 100644 index 00000000..3a4ffca5 --- /dev/null +++ b/mocks/models/PingRepository.go @@ -0,0 +1,71 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import models "github.com/NodeFactoryIo/vedran/internal/models" + +// PingRepository is an autogenerated mock type for the PingRepository type +type PingRepository struct { + mock.Mock +} + +// FindByNodeID provides a mock function with given fields: nodeId +func (_m *PingRepository) FindByNodeID(nodeId string) (*models.Ping, error) { + ret := _m.Called(nodeId) + + var r0 *models.Ping + if rf, ok := ret.Get(0).(func(string) *models.Ping); ok { + r0 = rf(nodeId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.Ping) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(nodeId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetAll provides a mock function with given fields: +func (_m *PingRepository) GetAll() (*[]models.Ping, error) { + ret := _m.Called() + + var r0 *[]models.Ping + if rf, ok := ret.Get(0).(func() *[]models.Ping); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*[]models.Ping) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Save provides a mock function with given fields: ping +func (_m *PingRepository) Save(ping *models.Ping) error { + ret := _m.Called(ping) + + var r0 error + if rf, ok := ret.Get(0).(func(*models.Ping) error); ok { + r0 = rf(ping) + } else { + r0 = ret.Error(0) + } + + return r0 +}