Skip to content

Commit 3ae7954

Browse files
committed
server: Add authorization by checking against the connection ID in the database
1 parent 5586301 commit 3ae7954

6 files changed

+191
-40
lines changed

pkg/server/db.go

+32-31
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ type game struct {
3838
Player common.Disk
3939
}
4040

41-
func getGameAndOpponentAndConnectionIDs(ctx context.Context, args Args, host string) (game, string, []string, error) {
41+
func getGame(ctx context.Context, args Args, host string) (game, string, map[string]string, error) {
4242
// Get the whole item from DynamoDB.
4343
output, err := args.DB.GetItemWithContext(ctx, &dynamodb.GetItemInput{
4444
TableName: aws.String(args.TableName),
@@ -64,39 +64,39 @@ func getGameAndOpponentAndConnectionIDs(ctx context.Context, args Args, host str
6464
return game, "", nil, err
6565
}
6666

67-
// Get just the connection ID values.
68-
var connectionIDs []string
69-
for _, v := range item.Connections {
70-
connectionIDs = append(connectionIDs, v)
71-
}
72-
73-
return game, item.Opponent, connectionIDs, err
67+
return game, item.Opponent, item.Connections, err
7468
}
7569

76-
func updateGame(ctx context.Context, args Args, host string, game game) error {
70+
func updateGame(ctx context.Context, args Args, host string, game game, connName, connID string) error {
7771
gameBytes, err := json.Marshal(&game)
7872
if err != nil {
7973
return err
8074
}
8175

8276
update := expression.Set(expression.Name(attribGame), expression.Value(gameBytes))
77+
condition := expression.Name(attribConnections + "." + connName).Equal(expression.Value(connID))
8378

84-
_, err = updateItem(ctx, args, host, update, false)
79+
_, err = updateItem(ctx, args, host, update, condition, false)
8580
return err
8681
}
8782

88-
func updateGameOpponentSetConnection(ctx context.Context, args Args, host string, game game, opponent, connName, connID string) error {
83+
func createGame(ctx context.Context, args Args, host string, game game, opponent, connName, connID string) error {
8984
gameBytes, err := json.Marshal(&game)
9085
if err != nil {
9186
return err
9287
}
9388

9489
update := expression.
9590
Set(expression.Name(attribGame), expression.Value(gameBytes)).
96-
Set(expression.Name(attribOpponent), expression.Value(opponent)).
9791
Set(expression.Name(attribConnections), expression.Value(map[string]string{connName: connID}))
9892

99-
_, err = updateItem(ctx, args, host, update, false)
93+
if opponent != "" {
94+
update = update.Set(expression.Name(attribOpponent), expression.Value(opponent))
95+
}
96+
97+
condition := expression.Name(attribHost).AttributeNotExists()
98+
99+
_, err = updateItem(ctx, args, host, update, condition, false)
100100
return err
101101
}
102102

@@ -106,7 +106,7 @@ func updateOpponentConnectionGetGameConnectionIDs(ctx context.Context, args Args
106106
Set(expression.Name(attribConnections+"."+connName), expression.Value(connID))
107107
condition := expression.In(expression.Name(attribOpponent), expression.Value(expectedOpponents[0]), expression.Value(expectedOpponents[1]))
108108

109-
output, err := updateItemWithCondition(ctx, args, host, update, condition, true)
109+
output, err := updateItem(ctx, args, host, update, condition, true)
110110
if err != nil {
111111
return game{}, nil, err
112112
}
@@ -158,11 +158,24 @@ func getHostsByOpponent(ctx context.Context, args Args, opponent string) ([]stri
158158
return hosts, nil
159159
}
160160

161-
func deleteGameGetConnectionIDs(ctx context.Context, args Args, host string) ([]string, error) {
161+
func deleteGameGetConnectionIDs(ctx context.Context, args Args, host, connName, connID string) ([]string, error) {
162+
exp, err := expression.NewBuilder().
163+
WithCondition(expression.Or(
164+
expression.Name(attribConnections+"."+connName).Equal(expression.Value(connID)),
165+
expression.Name(attribHost).AttributeNotExists(),
166+
)).
167+
Build()
168+
if err != nil {
169+
return nil, err
170+
}
171+
162172
output, err := args.DB.DeleteItemWithContext(ctx, &dynamodb.DeleteItemInput{
163-
TableName: aws.String(args.TableName),
164-
Key: hostKey(host),
165-
ReturnValues: aws.String(dynamodb.ReturnValueAllOld),
173+
TableName: aws.String(args.TableName),
174+
Key: hostKey(host),
175+
ConditionExpression: exp.Condition(),
176+
ExpressionAttributeNames: exp.Names(),
177+
ExpressionAttributeValues: exp.Values(),
178+
ReturnValues: aws.String(dynamodb.ReturnValueAllOld),
166179
})
167180
if err != nil {
168181
return nil, err
@@ -184,21 +197,9 @@ func deleteGameGetConnectionIDs(ctx context.Context, args Args, host string) ([]
184197
}
185198

186199
// updateItem wraps dynamodb.UpdateItemWithContext.
187-
func updateItem(ctx context.Context, args Args, host string, update expression.UpdateBuilder, returnOldValues bool) (*dynamodb.UpdateItemOutput, error) {
188-
update = update.Set(expression.Name(attribTTL), expression.Value(time.Now().Add(time.Hour).Unix()))
189-
builder := expression.NewBuilder().WithUpdate(update)
190-
return updateItemWithBuilder(ctx, args, host, builder, returnOldValues)
191-
}
192-
193-
// updateItemWithCondition wraps dynamodb.UpdateItemWithContext.
194-
func updateItemWithCondition(ctx context.Context, args Args, host string, update expression.UpdateBuilder, condition expression.ConditionBuilder, returnOldValues bool) (*dynamodb.UpdateItemOutput, error) {
200+
func updateItem(ctx context.Context, args Args, host string, update expression.UpdateBuilder, condition expression.ConditionBuilder, returnOldValues bool) (*dynamodb.UpdateItemOutput, error) {
195201
update = update.Set(expression.Name(attribTTL), expression.Value(time.Now().Add(time.Hour).Unix()))
196202
builder := expression.NewBuilder().WithUpdate(update).WithCondition(condition)
197-
return updateItemWithBuilder(ctx, args, host, builder, returnOldValues)
198-
}
199-
200-
// updateItemWithBuilder wraps dynamodb.UpdateItemWithContext.
201-
func updateItemWithBuilder(ctx context.Context, args Args, host string, builder expression.Builder, returnOldValues bool) (*dynamodb.UpdateItemOutput, error) {
202203
exp, err := builder.Build()
203204
if err != nil {
204205
return nil, err

pkg/server/handle_gameplay.go

+22-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package server
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"log"
78
"os"
@@ -17,11 +18,28 @@ import (
1718
// Handlers for messages pertaining to gameplay.
1819

1920
func handlePlaceDisk(ctx context.Context, req events.APIGatewayWebsocketProxyRequest, args Args, message *messages.PlaceDisk) error {
20-
game, opponent, connectionIDs, err := getGameAndOpponentAndConnectionIDs(ctx, args, message.Host)
21+
game, opponent, connections, err := getGame(ctx, args, message.Host)
2122
if err != nil {
2223
return fmt.Errorf("failed to load game state: %w", err)
2324
}
2425

26+
authorized := false
27+
for k, v := range connections {
28+
if k == message.Nickname && v == req.RequestContext.ConnectionID {
29+
authorized = true
30+
break
31+
}
32+
}
33+
34+
if !authorized {
35+
return errors.New("unauthorized")
36+
}
37+
38+
var connectionIDs []string
39+
for _, v := range connections {
40+
connectionIDs = append(connectionIDs, v)
41+
}
42+
2543
if opponent == "" {
2644
return handlePlaceDiskSolo(ctx, req.RequestContext, args, message, game)
2745
}
@@ -41,7 +59,7 @@ func handlePlaceDiskSolo(ctx context.Context, reqCtx events.APIGatewayWebsocketP
4159
game.Player = game.Player%2 + 1
4260
}
4361

44-
if err := updateGame(ctx, args, message.Host, game); err != nil {
62+
if err := updateGame(ctx, args, message.Host, game, message.Nickname, reqCtx.ConnectionID); err != nil {
4563
return fmt.Errorf("failed to save updated game state: %w", err)
4664
}
4765

@@ -65,7 +83,7 @@ func handlePlaceDiskSolo(ctx context.Context, reqCtx events.APIGatewayWebsocketP
6583
game.Player = 1
6684
}
6785

68-
if err := updateGame(ctx, args, message.Host, game); err != nil {
86+
if err := updateGame(ctx, args, message.Host, game, message.Nickname, reqCtx.ConnectionID); err != nil {
6987
return fmt.Errorf("failed to save updated game state: %w", err)
7088
}
7189

@@ -94,7 +112,7 @@ func handlePlaceDiskMultiplayer(ctx context.Context, reqCtx events.APIGatewayWeb
94112
game.Player = player%2 + 1
95113
}
96114

97-
if err := updateGame(ctx, args, message.Host, game); err != nil {
115+
if err := updateGame(ctx, args, message.Host, game, message.Nickname, reqCtx.ConnectionID); err != nil {
98116
return fmt.Errorf("failed to save updated game state: %w", err)
99117
}
100118

pkg/server/handle_sessions.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ const waiting = "#waiting"
2020
func handleHostGame(ctx context.Context, req events.APIGatewayWebsocketProxyRequest, args Args, message *messages.HostGame) error {
2121
game := newGame()
2222

23-
if err := updateGameOpponentSetConnection(ctx, args, message.Nickname, game, waiting, message.Nickname, req.RequestContext.ConnectionID); err != nil {
23+
if err := createGame(ctx, args, message.Nickname, game, waiting, message.Nickname, req.RequestContext.ConnectionID); err != nil {
2424
return fmt.Errorf("failed to save new game state: %w", err)
2525
}
2626

@@ -31,7 +31,7 @@ func handleStartSoloGame(ctx context.Context, req events.APIGatewayWebsocketProx
3131
game := newGame()
3232
game.Difficulty = message.Difficulty
3333

34-
if err := updateGame(ctx, args, message.Nickname, game); err != nil {
34+
if err := createGame(ctx, args, message.Nickname, game, "", message.Nickname, req.RequestContext.ConnectionID); err != nil {
3535
return fmt.Errorf("failed to save new game state: %w", err)
3636
}
3737

@@ -75,7 +75,7 @@ func handleListOpenGames(ctx context.Context, req events.APIGatewayWebsocketProx
7575
}
7676

7777
func handleLeaveGame(ctx context.Context, req events.APIGatewayWebsocketProxyRequest, args Args, message *messages.LeaveGame) error {
78-
connectionIDs, err := deleteGameGetConnectionIDs(ctx, args, message.Host)
78+
connectionIDs, err := deleteGameGetConnectionIDs(ctx, args, message.Host, message.Nickname, req.RequestContext.ConnectionID)
7979
if err != nil {
8080
return err
8181
}

pkg/server/handler.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func handleMessage(ctx context.Context, req events.APIGatewayWebsocketProxyReque
7272

7373
message := wrapper.Message
7474

75-
log.Printf("Handling message %T", message)
75+
log.Printf("Handling message %T from connection %s", message, req.RequestContext.ConnectionID)
7676

7777
if err := validate.Struct(message); err != nil {
7878
return err

pkg/server/management_api.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func reply(ctx context.Context, reqCtx events.APIGatewayWebsocketProxyRequestCon
3434

3535
func sendMessage(ctx context.Context, reqCtx events.APIGatewayWebsocketProxyRequestContext, args Args, connectionID string, message interface{}) func() error {
3636
return func() error {
37-
log.Printf("Sending message to connection %s", connectionID)
37+
log.Printf("Sending message %T to connection %s", message, connectionID)
3838

3939
data, err := json.Marshal(messages.Wrapper{Message: message})
4040
if err != nil {

0 commit comments

Comments
 (0)