Skip to content

Commit 6c0e3aa

Browse files
authored
fix(/commit): protect the commit endpoint via acl (#7608)
/commit endpoint was not ACL protected. In a multi-tenant system, it could be disastrous where a malicious user can commit or abort the transactions of any namespace. This PR partially fixes the issue.
1 parent d5299b9 commit 6c0e3aa

File tree

2 files changed

+45
-13
lines changed

2 files changed

+45
-13
lines changed

dgraph/cmd/alpha/http.go

+11-12
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,11 @@ import (
3434

3535
"github.com/dgraph-io/dgraph/graphql/admin"
3636

37-
"github.com/dgraph-io/dgo/v200"
3837
"github.com/dgraph-io/dgo/v200/protos/api"
3938
"github.com/dgraph-io/dgraph/edgraph"
4039
"github.com/dgraph-io/dgraph/gql"
4140
"github.com/dgraph-io/dgraph/graphql/schema"
4241
"github.com/dgraph-io/dgraph/query"
43-
"github.com/dgraph-io/dgraph/worker"
4442
"github.com/dgraph-io/dgraph/x"
4543
"github.com/gogo/protobuf/jsonpb"
4644
"github.com/golang/glog"
@@ -471,17 +469,18 @@ func commitHandler(w http.ResponseWriter, r *http.Request) {
471469
return
472470
}
473471

472+
ctx := x.AttachAccessJwt(context.Background(), r)
474473
var response map[string]interface{}
475474
if abort {
476-
response, err = handleAbort(startTs)
475+
response, err = handleAbort(ctx, startTs)
477476
} else {
478477
// Keys are sent as an array in the body.
479478
reqText := readRequest(w, r)
480479
if reqText == nil {
481480
return
482481
}
483482

484-
response, err = handleCommit(startTs, reqText)
483+
response, err = handleCommit(ctx, startTs, reqText)
485484
}
486485
if err != nil {
487486
x.SetStatus(w, x.ErrorInvalidRequest, err.Error())
@@ -497,27 +496,28 @@ func commitHandler(w http.ResponseWriter, r *http.Request) {
497496
_, _ = x.WriteResponse(w, r, js)
498497
}
499498

500-
func handleAbort(startTs uint64) (map[string]interface{}, error) {
499+
func handleAbort(ctx context.Context, startTs uint64) (map[string]interface{}, error) {
501500
tc := &api.TxnContext{
502501
StartTs: startTs,
503502
Aborted: true,
504503
}
505504

506-
_, err := worker.CommitOverNetwork(context.Background(), tc)
507-
switch err {
508-
case dgo.ErrAborted:
505+
tctx, err := (&edgraph.Server{}).CommitOrAbort(ctx, tc)
506+
switch {
507+
case tctx.Aborted:
509508
return map[string]interface{}{
510509
"code": x.Success,
511510
"message": "Done",
512511
}, nil
513-
case nil:
512+
case err == nil:
514513
return nil, errors.Errorf("transaction could not be aborted")
515514
default:
516515
return nil, err
517516
}
518517
}
519518

520-
func handleCommit(startTs uint64, reqText []byte) (map[string]interface{}, error) {
519+
func handleCommit(ctx context.Context, startTs uint64, reqText []byte) (map[string]interface{},
520+
error) {
521521
tc := &api.TxnContext{
522522
StartTs: startTs,
523523
}
@@ -540,14 +540,13 @@ func handleCommit(startTs uint64, reqText []byte) (map[string]interface{}, error
540540
tc.Preds = reqMap["preds"]
541541
}
542542

543-
cts, err := worker.CommitOverNetwork(context.Background(), tc)
543+
tc, err := (&edgraph.Server{}).CommitOrAbort(ctx, tc)
544544
if err != nil {
545545
return nil, err
546546
}
547547

548548
resp := &api.Response{}
549549
resp.Txn = tc
550-
resp.Txn.CommitTs = cts
551550
e := query.Extensions{
552551
Txn: resp.Txn,
553552
}

edgraph/server.go

+34-1
Original file line numberDiff line numberDiff line change
@@ -1471,6 +1471,33 @@ func authorizeRequest(ctx context.Context, qc *queryContext) error {
14711471
return nil
14721472
}
14731473

1474+
func validateNamespace(ctx context.Context, preds []string) error {
1475+
ns, err := x.ExtractJWTNamespace(ctx)
1476+
if err != nil {
1477+
return err
1478+
}
1479+
1480+
// Do a basic validation that all the predicates passed in transaction context matches the
1481+
// claimed namespace and user is not accidently commiting a transaction that it did not create.
1482+
for _, pred := range preds {
1483+
// Format for Preds in TxnContext is gid-<namespace><pred> (see fillPreds in posting pkg)
1484+
splits := strings.Split(pred, "-")
1485+
if len(splits) < 2 {
1486+
return errors.Errorf("Unable to find group id in %s", pred)
1487+
}
1488+
pred = strings.Join(splits[1:], "-")
1489+
if len(pred) < 8 {
1490+
return errors.Errorf("found invalid pred %s of length < 8 in transaction context", pred)
1491+
}
1492+
if parsedNs := x.ParseNamespace(pred); parsedNs != ns {
1493+
return errors.Errorf("Please login into correct namespace. "+
1494+
"Currently logged in namespace %#x", ns)
1495+
}
1496+
}
1497+
1498+
return nil
1499+
}
1500+
14741501
// CommitOrAbort commits or aborts a transaction.
14751502
func (s *Server) CommitOrAbort(ctx context.Context, tc *api.TxnContext) (*api.TxnContext, error) {
14761503
ctx, span := otrace.StartSpan(ctx, "Server.CommitOrAbort")
@@ -1480,6 +1507,12 @@ func (s *Server) CommitOrAbort(ctx context.Context, tc *api.TxnContext) (*api.Tx
14801507
return &api.TxnContext{}, err
14811508
}
14821509

1510+
if x.WorkerConfig.AclEnabled {
1511+
if err := validateNamespace(ctx, tc.Preds); err != nil {
1512+
return &api.TxnContext{}, err
1513+
}
1514+
}
1515+
14831516
tctx := &api.TxnContext{}
14841517
if tc.StartTs == 0 {
14851518
return &api.TxnContext{}, errors.Errorf(
@@ -1492,11 +1525,11 @@ func (s *Server) CommitOrAbort(ctx context.Context, tc *api.TxnContext) (*api.Tx
14921525
if err == dgo.ErrAborted {
14931526
// If err returned is dgo.ErrAborted and tc.Aborted was set, that means the client has
14941527
// aborted the transaction by calling txn.Discard(). Hence return a nil error.
1528+
tctx.Aborted = true
14951529
if tc.Aborted {
14961530
return tctx, nil
14971531
}
14981532

1499-
tctx.Aborted = true
15001533
return tctx, status.Errorf(codes.Aborted, err.Error())
15011534
}
15021535
tctx.StartTs = tc.StartTs

0 commit comments

Comments
 (0)