Skip to content

Commit

Permalink
Login fixes, add login to test client
Browse files Browse the repository at this point in the history
  • Loading branch information
DJAndries committed Oct 29, 2024
1 parent db99700 commit a24ec50
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 34 deletions.
4 changes: 2 additions & 2 deletions controllers/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ func (ac *AuthController) Router(authMiddleware func(http.Handler) http.Handler)
r := chi.NewRouter()

r.With(authMiddleware).Get("/validate", ac.Validate)
r.With(authMiddleware).Post("/login/init", ac.LoginInit)
r.With(authMiddleware).Post("/login/finalize", ac.LoginInit)
r.Post("/login/init", ac.LoginInit)
r.Post("/login/finalize", ac.LoginFinalize)

return r
}
Expand Down
8 changes: 3 additions & 5 deletions controllers/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@ import (
)

type SessionsController struct {
datastore *datastore.Datastore
minSessionVersion int
datastore *datastore.Datastore
}

func NewSessionsController(datastore *datastore.Datastore, minSessionVersion int) *SessionsController {
func NewSessionsController(datastore *datastore.Datastore) *SessionsController {
return &SessionsController{
datastore,
minSessionVersion,
}
}

Expand All @@ -36,7 +34,7 @@ func NewSessionsController(datastore *datastore.Datastore, minSessionVersion int
func (sc *SessionsController) ListSessions(w http.ResponseWriter, r *http.Request) {
session := r.Context().Value(middleware.ContextSession).(*datastore.Session)

sessions, err := sc.datastore.ListSessions(session.AccountID, sc.minSessionVersion)
sessions, err := sc.datastore.ListSessions(session.AccountID, nil)
if err != nil {
util.RenderErrorResponse(w, r, http.StatusInternalServerError, err)
return
Expand Down
5 changes: 3 additions & 2 deletions datastore/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ func (d *Datastore) GetOrCreateAccount(email string) (*Account, error) {
var account *Account

err := d.db.Transaction(func(tx *gorm.DB) error {
account, err := d.GetAccount(tx, email)
var err error
account, err = d.GetAccount(tx, email)

if err != nil {
if errors.Is(err, ErrAccountNotFound) {
Expand All @@ -60,7 +61,7 @@ func (d *Datastore) GetOrCreateAccount(email string) (*Account, error) {
Email: email,
}

if err := tx.Create(&account).Error; err != nil {
if err := tx.Create(account).Error; err != nil {
return fmt.Errorf("error creating account: %w", err)
}

Expand Down
9 changes: 5 additions & 4 deletions datastore/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ import (
const databaseURLEnv = "DATABASE_URL"

type Datastore struct {
dbConfig *pgx.ConnConfig
db *gorm.DB
dbConfig *pgx.ConnConfig
db *gorm.DB
newSessionVersion int
}

func NewDatastore() (*Datastore, error) {
func NewDatastore(newSessionVersion int) (*Datastore, error) {
dbURL := os.Getenv(databaseURLEnv)
if dbURL == "" {
return nil, fmt.Errorf("DATABASE_URL environment variable not set")
Expand Down Expand Up @@ -59,5 +60,5 @@ func NewDatastore() (*Datastore, error) {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}

return &Datastore{dbConfig: dbConfig, db: db}, nil
return &Datastore{dbConfig: dbConfig, db: db, newSessionVersion: newSessionVersion}, nil
}
9 changes: 6 additions & 3 deletions datastore/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (d *Datastore) CreateSession(accountID uuid.UUID, sessionName *string) (*Se
ID: id,
AccountID: accountID,
SessionName: sessionName,
Version: 1,
Version: d.newSessionVersion,
}

if err := d.db.Create(&session).Error; err != nil {
Expand All @@ -46,9 +46,12 @@ func (d *Datastore) CreateSession(accountID uuid.UUID, sessionName *string) (*Se
return &session, nil
}

func (d *Datastore) ListSessions(accountID uuid.UUID, minSessionVersion int) ([]Session, error) {
func (d *Datastore) ListSessions(accountID uuid.UUID, minSessionVersion *int) ([]Session, error) {
var sessions []Session
if err := d.db.Where("account_id = ? AND version >= ?", accountID, minSessionVersion).Find(&sessions).Error; err != nil {
if minSessionVersion == nil {
minSessionVersion = &d.newSessionVersion
}
if err := d.db.Where("account_id = ? AND version >= ?", accountID, *minSessionVersion).Find(&sessions).Error; err != nil {
return nil, fmt.Errorf("failed to list sessions: %w", err)
}

Expand Down
14 changes: 7 additions & 7 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ func main() {

passwordAuthEnabled := os.Getenv(passwordAuthEnabledEnv) == "true"

datastore, err := datastore.NewDatastore()
minSessionVersion := 1
if passwordAuthEnabled {
minSessionVersion = 2
}

datastore, err := datastore.NewDatastore(minSessionVersion)
if err != nil {
log.Panic().Err(err).Msg("Failed to init datastore")
}
Expand All @@ -67,11 +72,6 @@ func main() {
log.Panic().Err(err).Msg("Failed to init OPAQUE service")
}

minSessionVersion := 1
if passwordAuthEnabled {
minSessionVersion = 2
}

authMiddleware := middleware.AuthMiddleware(jwtUtil, datastore, minSessionVersion)
verificationAuthMiddleware := middleware.VerificationAuthMiddleware(jwtUtil, datastore)

Expand All @@ -80,7 +80,7 @@ func main() {
authController := controllers.NewAuthController(opaqueService, jwtUtil, datastore)
accountsController := controllers.NewAccountsController(opaqueService, jwtUtil, datastore)
verificationController := controllers.NewVerificationController(datastore, jwtUtil, sesUtil, passwordAuthEnabled)
sessionsController := controllers.NewSessionsController(datastore, minSessionVersion)
sessionsController := controllers.NewSessionsController(datastore)

r.Use(middleware.LoggerMiddleware)

Expand Down
130 changes: 121 additions & 9 deletions misc/test-client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@ import (
"log"
"net/http"
"os"
"strings"

"github.com/bytemare/opaque"
"github.com/bytemare/opaque/message"
)

func postReq(fields map[string]interface{}, url string, verificationToken string) map[string]interface{} {
var conf = opaque.DefaultConfiguration()

func postReq(fields map[string]interface{}, url string, authToken *string) map[string]interface{} {
jsonBody, err := json.Marshal(fields)
if err != nil {
log.Fatal(err)
Expand All @@ -29,7 +32,9 @@ func postReq(fields map[string]interface{}, url string, verificationToken string
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+verificationToken)
if authToken != nil {
req.Header.Set("Authorization", "Bearer "+*authToken)
}

httpClient := &http.Client{}
resp, err := httpClient.Do(req)
Expand Down Expand Up @@ -57,11 +62,8 @@ func postReq(fields map[string]interface{}, url string, verificationToken string
return respBody
}

func main() {
fmt.Print("Enter verification token: ")
func scanCredentials() (string, string) {
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
verificationToken := scanner.Text()

fmt.Print("Enter email: ")
scanner.Scan()
Expand All @@ -70,8 +72,16 @@ func main() {
fmt.Print("Enter password: ")
scanner.Scan()
password := scanner.Text()
return email, password
}

func register() {
fmt.Print("Enter verification token: ")
scanner := bufio.NewScanner(os.Stdin)
scanner.Scan()
verificationToken := scanner.Text()

conf := opaque.DefaultConfiguration()
email, password := scanCredentials()

client, err := conf.Client()
if err != nil {
Expand All @@ -87,7 +97,7 @@ func main() {
"blindedMessage": hex.EncodeToString(blindedMessage),
}

resp := postReq(initFields, "http://localhost:8080/v2/accounts/setup/init", verificationToken)
resp := postReq(initFields, "http://localhost:8080/v2/accounts/setup/init", &verificationToken)

evalMsgBytes, err := hex.DecodeString(resp["evaluatedMessage"].(string))
if err != nil {
Expand Down Expand Up @@ -125,7 +135,109 @@ func main() {
"envelope": hex.EncodeToString(record.Envelope),
}

resp = postReq(recordFields, "http://localhost:8080/v2/accounts/setup/finalize", verificationToken)
resp = postReq(recordFields, "http://localhost:8080/v2/accounts/setup/finalize", &verificationToken)

log.Printf("auth token: %v", resp["authToken"])
}

func login() {
email, password := scanCredentials()

client, err := conf.Client()
if err != nil {
log.Fatalln(err)
}

initReq := client.LoginInit([]byte(password))
blindedMessage, err := initReq.BlindedMessage.MarshalBinary()
epk, err := initReq.EpkU.MarshalBinary()
if err != nil {
log.Fatalln(err)
}
initFields := map[string]interface{}{
"email": email,
"blindedMessage": hex.EncodeToString(blindedMessage),
"clientEphemeralPublicKey": hex.EncodeToString(epk),
"clientNonce": hex.EncodeToString(initReq.NonceU),
}

resp := postReq(initFields, "http://localhost:8080/v2/auth/login/init", nil)

evalMsgBytes, err := hex.DecodeString(resp["evaluatedMessage"].(string))
if err != nil {
log.Fatalln(err)
}
maskingNonce, err := hex.DecodeString(resp["maskingNonce"].(string))
if err != nil {
log.Fatalln(err)
}
maskedResponse, err := hex.DecodeString(resp["maskedResponse"].(string))
if err != nil {
log.Fatalln(err)
}
epkBytes, err := hex.DecodeString(resp["serverEphemeralPublicKey"].(string))
if err != nil {
log.Fatalln(err)
}
serverNonce, err := hex.DecodeString(resp["serverNonce"].(string))
if err != nil {
log.Fatalln(err)
}
serverMac, err := hex.DecodeString(resp["serverMac"].(string))
if err != nil {
log.Fatalln(err)
}
evalMsg := conf.OPRF.Group().NewElement()
epkElement := conf.OPRF.Group().NewElement()
if err = evalMsg.UnmarshalBinary(evalMsgBytes); err != nil {
log.Fatalln(err)
}
if err = epkElement.UnmarshalBinary(epkBytes); err != nil {
log.Fatalln(err)
}

opaqueResp := message.KE2{
CredentialResponse: &message.CredentialResponse{
EvaluatedMessage: evalMsg,
MaskingNonce: maskingNonce,
MaskedResponse: maskedResponse,
},
EpkS: epkElement,
NonceS: serverNonce,
Mac: serverMac,
}

ke3, _, err := client.LoginFinish(&opaqueResp, opaque.ClientLoginFinishOptions{
ClientIdentity: []byte(email),
})
if err != nil {
log.Fatalln(err)
}

finalizeFields := map[string]interface{}{
"clientMac": hex.EncodeToString(ke3.Mac),
}
akeToken := resp["akeToken"].(string)
resp = postReq(finalizeFields, "http://localhost:8080/v2/auth/login/finalize", &akeToken)

log.Printf("auth token: %v", resp["authToken"])
}

func main() {
fmt.Println("1. Login")
fmt.Println("2. Register")
fmt.Print("Choose an option (1-2): ")

reader := bufio.NewReader(os.Stdin)
choice, _ := reader.ReadString('\n')
choice = strings.TrimSpace(choice)

switch choice {
case "1":
login()
case "2":
register()
default:
fmt.Println("Invalid option")
}
}
3 changes: 1 addition & 2 deletions services/opaque.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,7 @@ func (o *OpaqueService) LoginInit(email string, ke1 *opaqueMsg.KE1) (*opaqueMsg.
opaqueRecord = &opaque.ClientRecord{
RegistrationRecord: opaqueRegistration,
CredentialIdentifier: []byte(email),
ClientIdentity: nil,
TestMaskNonce: nil,
ClientIdentity: []byte(email),
}
}

Expand Down

0 comments on commit a24ec50

Please sign in to comment.