diff --git a/internal/datastore/crdb/readwrite.go b/internal/datastore/crdb/readwrite.go index acfaf97709..43e914cf86 100644 --- a/internal/datastore/crdb/readwrite.go +++ b/internal/datastore/crdb/readwrite.go @@ -386,7 +386,7 @@ func exactRelationshipClause(r tuple.Relationship) sq.Eq { } } -func (rwt *crdbReadWriteTXN) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (bool, error) { +func (rwt *crdbReadWriteTXN) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (uint64, bool, error) { // Add clauses for the ResourceFilter query := rwt.queryDeleteTuples() @@ -401,7 +401,7 @@ func (rwt *crdbReadWriteTXN) DeleteRelationships(ctx context.Context, filter *v1 } if filter.OptionalResourceIdPrefix != "" { if strings.Contains(filter.OptionalResourceIdPrefix, "%") { - return false, fmt.Errorf("unable to delete relationships with a prefix containing the %% character") + return 0, false, fmt.Errorf("unable to delete relationships with a prefix containing the %% character") } query = query.Where(sq.Like{colObjectID: filter.OptionalResourceIdPrefix + "%"}) @@ -434,24 +434,24 @@ func (rwt *crdbReadWriteTXN) DeleteRelationships(ctx context.Context, filter *v1 sql, args, err := query.ToSql() if err != nil { - return false, fmt.Errorf(errUnableToDeleteRelationships, err) + return 0, false, fmt.Errorf(errUnableToDeleteRelationships, err) } modified, err := rwt.tx.Exec(ctx, sql, args...) if err != nil { - return false, fmt.Errorf(errUnableToDeleteRelationships, err) + return 0, false, fmt.Errorf(errUnableToDeleteRelationships, err) } rwt.relCountChange -= modified.RowsAffected() rowsAffected, err := safecast.ToUint64(modified.RowsAffected()) if err != nil { - return false, spiceerrors.MustBugf("could not cast RowsAffected to uint64: %v", err) + return 0, false, spiceerrors.MustBugf("could not cast RowsAffected to uint64: %v", err) } if delLimit > 0 && rowsAffected == delLimit { - return true, nil + return rowsAffected, true, nil } - return false, nil + return rowsAffected, false, nil } func (rwt *crdbReadWriteTXN) WriteNamespaces(ctx context.Context, newConfigs ...*core.NamespaceDefinition) error { diff --git a/internal/datastore/memdb/readwrite.go b/internal/datastore/memdb/readwrite.go index 4add605c2a..8929e84ca7 100644 --- a/internal/datastore/memdb/readwrite.go +++ b/internal/datastore/memdb/readwrite.go @@ -134,13 +134,13 @@ func (rwt *memdbReadWriteTx) toCaveatReference(mutation tuple.RelationshipUpdate return cr } -func (rwt *memdbReadWriteTx) DeleteRelationships(_ context.Context, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (bool, error) { +func (rwt *memdbReadWriteTx) DeleteRelationships(_ context.Context, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (uint64, bool, error) { rwt.mustLock() defer rwt.Unlock() tx, err := rwt.txSource() if err != nil { - return false, err + return 0, false, err } delOpts := options.NewDeleteOptionsWithOptionsAndDefaults(opts...) @@ -153,16 +153,16 @@ func (rwt *memdbReadWriteTx) DeleteRelationships(_ context.Context, filter *v1.R } // caller must already hold the concurrent access lock -func (rwt *memdbReadWriteTx) deleteWithLock(tx *memdb.Txn, filter *v1.RelationshipFilter, limit uint64) (bool, error) { +func (rwt *memdbReadWriteTx) deleteWithLock(tx *memdb.Txn, filter *v1.RelationshipFilter, limit uint64) (uint64, bool, error) { // Create an iterator to find the relevant tuples dsFilter, err := datastore.RelationshipsFilterFromPublicFilter(filter) if err != nil { - return false, err + return 0, false, err } bestIter, err := iteratorForFilter(tx, dsFilter) if err != nil { - return false, err + return 0, false, err } filteredIter := memdb.NewFilterIterator(bestIter, relationshipFilterFilterFunc(filter)) @@ -174,7 +174,7 @@ func (rwt *memdbReadWriteTx) deleteWithLock(tx *memdb.Txn, filter *v1.Relationsh for row := filteredIter.Next(); row != nil; row = filteredIter.Next() { rt, err := row.(*relationship).Relationship() if err != nil { - return false, err + return 0, false, err } mutations = append(mutations, tuple.Delete(rt)) counter++ @@ -185,7 +185,7 @@ func (rwt *memdbReadWriteTx) deleteWithLock(tx *memdb.Txn, filter *v1.Relationsh } } - return metLimit, rwt.write(tx, mutations...) + return counter, metLimit, rwt.write(tx, mutations...) } func (rwt *memdbReadWriteTx) RegisterCounter(ctx context.Context, name string, filter *core.RelationshipFilter) error { @@ -320,7 +320,7 @@ func (rwt *memdbReadWriteTx) DeleteNamespaces(_ context.Context, nsNames ...stri } // Delete the relationships from the namespace - if _, err := rwt.deleteWithLock(tx, &v1.RelationshipFilter{ + if _, _, err := rwt.deleteWithLock(tx, &v1.RelationshipFilter{ ResourceType: nsName, }, 0); err != nil { return fmt.Errorf("unable to delete relationships from deleted namespace: %w", err) diff --git a/internal/datastore/mysql/readwrite.go b/internal/datastore/mysql/readwrite.go index d8b5b40647..b2fff1dcc3 100644 --- a/internal/datastore/mysql/readwrite.go +++ b/internal/datastore/mysql/readwrite.go @@ -339,7 +339,7 @@ func (rwt *mysqlReadWriteTXN) WriteRelationships(ctx context.Context, mutations return nil } -func (rwt *mysqlReadWriteTXN) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (bool, error) { +func (rwt *mysqlReadWriteTXN) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (uint64, bool, error) { // Add clauses for the ResourceFilter query := rwt.DeleteRelsQuery if filter.ResourceType != "" { @@ -353,7 +353,7 @@ func (rwt *mysqlReadWriteTXN) DeleteRelationships(ctx context.Context, filter *v } if filter.OptionalResourceIdPrefix != "" { if strings.Contains(filter.OptionalResourceIdPrefix, "%") { - return false, fmt.Errorf("unable to delete relationships with a prefix containing the %% character") + return 0, false, fmt.Errorf("unable to delete relationships with a prefix containing the %% character") } query = query.Where(sq.Like{colObjectID: filter.OptionalResourceIdPrefix + "%"}) @@ -385,29 +385,29 @@ func (rwt *mysqlReadWriteTXN) DeleteRelationships(ctx context.Context, filter *v querySQL, args, err := query.ToSql() if err != nil { - return false, fmt.Errorf(errUnableToDeleteRelationships, err) + return 0, false, fmt.Errorf(errUnableToDeleteRelationships, err) } modified, err := rwt.tx.ExecContext(ctx, querySQL, args...) if err != nil { - return false, fmt.Errorf(errUnableToDeleteRelationships, err) + return 0, false, fmt.Errorf(errUnableToDeleteRelationships, err) } rowsAffected, err := modified.RowsAffected() if err != nil { - return false, fmt.Errorf(errUnableToDeleteRelationships, err) + return 0, false, fmt.Errorf(errUnableToDeleteRelationships, err) } uintRowsAffected, err := safecast.ToUint64(rowsAffected) if err != nil { - return false, spiceerrors.MustBugf("rowsAffected was negative: %v", err) + return 0, false, spiceerrors.MustBugf("rowsAffected was negative: %v", err) } if delLimit > 0 && uintRowsAffected == delLimit { - return true, nil + return uintRowsAffected, true, nil } - return false, nil + return uintRowsAffected, false, nil } func (rwt *mysqlReadWriteTXN) WriteNamespaces(ctx context.Context, newNamespaces ...*core.NamespaceDefinition) error { diff --git a/internal/datastore/postgres/readwrite.go b/internal/datastore/postgres/readwrite.go index e135b158ad..e7ed0d7347 100644 --- a/internal/datastore/postgres/readwrite.go +++ b/internal/datastore/postgres/readwrite.go @@ -419,20 +419,21 @@ func handleWriteError(err error) error { return fmt.Errorf(errUnableToWriteRelationships, err) } -func (rwt *pgReadWriteTXN) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (bool, error) { +func (rwt *pgReadWriteTXN) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (uint64, bool, error) { delOpts := options.NewDeleteOptionsWithOptionsAndDefaults(opts...) if delOpts.DeleteLimit != nil && *delOpts.DeleteLimit > 0 { return rwt.deleteRelationshipsWithLimit(ctx, filter, *delOpts.DeleteLimit) } - return false, rwt.deleteRelationships(ctx, filter) + numDeleted, err := rwt.deleteRelationships(ctx, filter) + return numDeleted, false, err } -func (rwt *pgReadWriteTXN) deleteRelationshipsWithLimit(ctx context.Context, filter *v1.RelationshipFilter, limit uint64) (bool, error) { +func (rwt *pgReadWriteTXN) deleteRelationshipsWithLimit(ctx context.Context, filter *v1.RelationshipFilter, limit uint64) (uint64, bool, error) { // validate the limit intLimit, err := safecast.ToInt64(limit) if err != nil { - return false, fmt.Errorf("limit argument could not safely be cast to int64: %w", err) + return 0, false, fmt.Errorf("limit argument could not safely be cast to int64: %w", err) } // Construct a select query for the relationships to be removed. @@ -449,7 +450,7 @@ func (rwt *pgReadWriteTXN) deleteRelationshipsWithLimit(ctx context.Context, fil } if filter.OptionalResourceIdPrefix != "" { if strings.Contains(filter.OptionalResourceIdPrefix, "%") { - return false, fmt.Errorf("unable to delete relationships with a prefix containing the %% character") + return 0, false, fmt.Errorf("unable to delete relationships with a prefix containing the %% character") } query = query.Where(sq.Like{colObjectID: filter.OptionalResourceIdPrefix + "%"}) @@ -470,7 +471,7 @@ func (rwt *pgReadWriteTXN) deleteRelationshipsWithLimit(ctx context.Context, fil selectSQL, args, err := query.ToSql() if err != nil { - return false, fmt.Errorf(errUnableToDeleteRelationships, err) + return 0, false, fmt.Errorf(errUnableToDeleteRelationships, err) } args = append(args, rwt.newXID) @@ -493,13 +494,18 @@ func (rwt *pgReadWriteTXN) deleteRelationshipsWithLimit(ctx context.Context, fil result, err := rwt.tx.Exec(ctx, cteSQL, args...) if err != nil { - return false, fmt.Errorf(errUnableToDeleteRelationships, err) + return 0, false, fmt.Errorf(errUnableToDeleteRelationships, err) } - return result.RowsAffected() == intLimit, nil + numDeleted, err := safecast.ToUint64(result.RowsAffected()) + if err != nil { + return 0, false, fmt.Errorf("unable to cast rows affected to uint64: %w", err) + } + + return numDeleted, result.RowsAffected() == intLimit, nil } -func (rwt *pgReadWriteTXN) deleteRelationships(ctx context.Context, filter *v1.RelationshipFilter) error { +func (rwt *pgReadWriteTXN) deleteRelationships(ctx context.Context, filter *v1.RelationshipFilter) (uint64, error) { // Add clauses for the ResourceFilter query := deleteTuple if filter.ResourceType != "" { @@ -513,7 +519,7 @@ func (rwt *pgReadWriteTXN) deleteRelationships(ctx context.Context, filter *v1.R } if filter.OptionalResourceIdPrefix != "" { if strings.Contains(filter.OptionalResourceIdPrefix, "%") { - return fmt.Errorf("unable to delete relationships with a prefix containing the %% character") + return 0, fmt.Errorf("unable to delete relationships with a prefix containing the %% character") } query = query.Where(sq.Like{colObjectID: filter.OptionalResourceIdPrefix + "%"}) @@ -532,14 +538,20 @@ func (rwt *pgReadWriteTXN) deleteRelationships(ctx context.Context, filter *v1.R sql, args, err := query.Set(colDeletedXid, rwt.newXID).ToSql() if err != nil { - return fmt.Errorf(errUnableToDeleteRelationships, err) + return 0, fmt.Errorf(errUnableToDeleteRelationships, err) + } + + result, err := rwt.tx.Exec(ctx, sql, args...) + if err != nil { + return 0, fmt.Errorf(errUnableToDeleteRelationships, err) } - if _, err := rwt.tx.Exec(ctx, sql, args...); err != nil { - return fmt.Errorf(errUnableToDeleteRelationships, err) + numDeleted, err := safecast.ToUint64(result.RowsAffected()) + if err != nil { + return 0, fmt.Errorf("unable to cast rows affected to uint64: %w", err) } - return nil + return numDeleted, nil } func (rwt *pgReadWriteTXN) WriteNamespaces(ctx context.Context, newConfigs ...*core.NamespaceDefinition) error { diff --git a/internal/datastore/proxy/observable.go b/internal/datastore/proxy/observable.go index 0b2744637a..649cd1fcdc 100644 --- a/internal/datastore/proxy/observable.go +++ b/internal/datastore/proxy/observable.go @@ -350,7 +350,7 @@ func (rwt *observableRWT) DeleteNamespaces(ctx context.Context, nsNames ...strin return rwt.delegate.DeleteNamespaces(ctx, nsNames...) } -func (rwt *observableRWT) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, options ...options.DeleteOptionsOption) (bool, error) { +func (rwt *observableRWT) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, options ...options.DeleteOptionsOption) (uint64, bool, error) { ctx, closer := observe(ctx, "DeleteRelationships", trace.WithAttributes( filterToAttributes(filter)..., )) diff --git a/internal/datastore/proxy/proxy_test/mock.go b/internal/datastore/proxy/proxy_test/mock.go index de32caec88..495862618b 100644 --- a/internal/datastore/proxy/proxy_test/mock.go +++ b/internal/datastore/proxy/proxy_test/mock.go @@ -269,9 +269,9 @@ func (dm *MockReadWriteTransaction) WriteRelationships(_ context.Context, mutati return args.Error(0) } -func (dm *MockReadWriteTransaction) DeleteRelationships(_ context.Context, filter *v1.RelationshipFilter, options ...options.DeleteOptionsOption) (bool, error) { +func (dm *MockReadWriteTransaction) DeleteRelationships(_ context.Context, filter *v1.RelationshipFilter, options ...options.DeleteOptionsOption) (uint64, bool, error) { args := dm.Called(filter) - return false, args.Error(0) + return 0, false, args.Error(0) } func (dm *MockReadWriteTransaction) WriteNamespaces(_ context.Context, newConfigs ...*core.NamespaceDefinition) error { diff --git a/internal/datastore/spanner/readwrite.go b/internal/datastore/spanner/readwrite.go index 57cc91dd5b..b2571d32e5 100644 --- a/internal/datastore/spanner/readwrite.go +++ b/internal/datastore/spanner/readwrite.go @@ -147,22 +147,22 @@ func spannerMutation( return } -func (rwt spannerReadWriteTXN) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (bool, error) { - limitReached, err := deleteWithFilter(ctx, rwt.spannerRWT, filter, opts...) +func (rwt spannerReadWriteTXN) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (uint64, bool, error) { + numDeleted, limitReached, err := deleteWithFilter(ctx, rwt.spannerRWT, filter, opts...) if err != nil { - return false, fmt.Errorf(errUnableToDeleteRelationships, err) + return 0, false, fmt.Errorf(errUnableToDeleteRelationships, err) } - return limitReached, nil + return numDeleted, limitReached, nil } -func deleteWithFilter(ctx context.Context, rwt *spanner.ReadWriteTransaction, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (bool, error) { +func deleteWithFilter(ctx context.Context, rwt *spanner.ReadWriteTransaction, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (uint64, bool, error) { delOpts := options.NewDeleteOptionsWithOptionsAndDefaults(opts...) var delLimit uint64 if delOpts.DeleteLimit != nil && *delOpts.DeleteLimit > 0 { delLimit = *delOpts.DeleteLimit if delLimit > inLimit { - return false, spiceerrors.MustBugf("delete limit %d exceeds maximum of %d in spanner", delLimit, inLimit) + return 0, false, spiceerrors.MustBugf("delete limit %d exceeds maximum of %d in spanner", delLimit, inLimit) } } @@ -170,13 +170,13 @@ func deleteWithFilter(ctx context.Context, rwt *spanner.ReadWriteTransaction, fi if delLimit > 0 { nu, err := deleteWithFilterAndLimit(ctx, rwt, filter, delLimit) if err != nil { - return false, err + return 0, false, err } numDeleted = nu } else { nu, err := deleteWithFilterAndNoLimit(ctx, rwt, filter) if err != nil { - return false, err + return 0, false, err } numDeleted = nu @@ -184,14 +184,14 @@ func deleteWithFilter(ctx context.Context, rwt *spanner.ReadWriteTransaction, fi uintNumDeleted, err := safecast.ToUint64(numDeleted) if err != nil { - return false, spiceerrors.MustBugf("numDeleted was negative: %v", err) + return 0, false, spiceerrors.MustBugf("numDeleted was negative: %v", err) } if delLimit > 0 && uintNumDeleted == delLimit { - return true, nil + return uintNumDeleted, true, nil } - return false, nil + return uintNumDeleted, false, nil } func deleteWithFilterAndLimit(ctx context.Context, rwt *spanner.ReadWriteTransaction, filter *v1.RelationshipFilter, delLimit uint64) (int64, error) { @@ -391,7 +391,7 @@ func (rwt spannerReadWriteTXN) DeleteNamespaces(ctx context.Context, nsNames ... // Ensure the namespace exists. relFilter := &v1.RelationshipFilter{ResourceType: nsName} - if _, err := deleteWithFilter(ctx, rwt.spannerRWT, relFilter); err != nil { + if _, _, err := deleteWithFilter(ctx, rwt.spannerRWT, relFilter); err != nil { return fmt.Errorf(errUnableToDeleteConfig, err) } diff --git a/internal/dispatch/remote/cluster.go b/internal/dispatch/remote/cluster.go index 28d3df6253..c28e1e0131 100644 --- a/internal/dispatch/remote/cluster.go +++ b/internal/dispatch/remote/cluster.go @@ -3,89 +3,40 @@ package remote import ( "context" "errors" - "fmt" "io" - "strings" - "time" - "github.com/authzed/consistent" - "github.com/prometheus/client_golang/prometheus" - "github.com/rs/zerolog" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" - "google.golang.org/protobuf/proto" "github.com/authzed/spicedb/internal/dispatch" "github.com/authzed/spicedb/internal/dispatch/keys" log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/balancer" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" - "github.com/authzed/spicedb/pkg/spiceerrors" ) -var dispatchCounter = prometheus.NewCounterVec(prometheus.CounterOpts{ - Namespace: "spicedb", - Subsystem: "dispatch", - Name: "remote_dispatch_handler_total", - Help: "which dispatcher handled a request", -}, []string{"request_kind", "handler_name"}) - -func init() { - prometheus.MustRegister(dispatchCounter) -} - -type ClusterClient interface { +type clusterClient interface { DispatchCheck(ctx context.Context, req *v1.DispatchCheckRequest, opts ...grpc.CallOption) (*v1.DispatchCheckResponse, error) DispatchExpand(ctx context.Context, req *v1.DispatchExpandRequest, opts ...grpc.CallOption) (*v1.DispatchExpandResponse, error) - DispatchLookupResources2(ctx context.Context, in *v1.DispatchLookupResources2Request, opts ...grpc.CallOption) (v1.DispatchService_DispatchLookupResources2Client, error) + DispatchLookup(ctx context.Context, req *v1.DispatchLookupRequest, opts ...grpc.CallOption) (*v1.DispatchLookupResponse, error) + DispatchReachableResources(ctx context.Context, in *v1.DispatchReachableResourcesRequest, opts ...grpc.CallOption) (v1.DispatchService_DispatchReachableResourcesClient, error) DispatchLookupSubjects(ctx context.Context, in *v1.DispatchLookupSubjectsRequest, opts ...grpc.CallOption) (v1.DispatchService_DispatchLookupSubjectsClient, error) } -type ClusterDispatcherConfig struct { - // KeyHandler is then handler to use for generating dispatch hash ring keys. - KeyHandler keys.Handler - - // DispatchOverallTimeout is the maximum duration of a dispatched request - // before it should timeout. - DispatchOverallTimeout time.Duration -} - -// SecondaryDispatch defines a struct holding a client and its name for secondary -// dispatching. -type SecondaryDispatch struct { - Name string - Client ClusterClient -} - // NewClusterDispatcher creates a dispatcher implementation that uses the provided client // to dispatch requests to peer nodes in the cluster. -func NewClusterDispatcher(client ClusterClient, conn *grpc.ClientConn, config ClusterDispatcherConfig, secondaryDispatch map[string]SecondaryDispatch, secondaryDispatchExprs map[string]*DispatchExpr) dispatch.Dispatcher { - keyHandler := config.KeyHandler +func NewClusterDispatcher(client clusterClient, conn *grpc.ClientConn, keyHandler keys.Handler) dispatch.Dispatcher { if keyHandler == nil { keyHandler = &keys.DirectKeyHandler{} } - dispatchOverallTimeout := config.DispatchOverallTimeout - if dispatchOverallTimeout <= 0 { - dispatchOverallTimeout = 60 * time.Second - } - - return &clusterDispatcher{ - clusterClient: client, - conn: conn, - keyHandler: keyHandler, - dispatchOverallTimeout: dispatchOverallTimeout, - secondaryDispatch: secondaryDispatch, - secondaryDispatchExprs: secondaryDispatchExprs, - } + return &clusterDispatcher{clusterClient: client, conn: conn, keyHandler: keyHandler} } type clusterDispatcher struct { - clusterClient ClusterClient - conn *grpc.ClientConn - keyHandler keys.Handler - dispatchOverallTimeout time.Duration - secondaryDispatch map[string]SecondaryDispatch - secondaryDispatchExprs map[string]*DispatchExpr + clusterClient clusterClient + conn *grpc.ClientConn + keyHandler keys.Handler } func (cr *clusterDispatcher) DispatchCheck(ctx context.Context, req *v1.DispatchCheckRequest) (*v1.DispatchCheckResponse, error) { @@ -98,326 +49,93 @@ func (cr *clusterDispatcher) DispatchCheck(ctx context.Context, req *v1.Dispatch return &v1.DispatchCheckResponse{Metadata: emptyMetadata}, err } - ctx = context.WithValue(ctx, consistent.CtxKey, requestKey) - - resp, err := dispatchRequest(ctx, cr, "check", req, func(ctx context.Context, client ClusterClient) (*v1.DispatchCheckResponse, error) { - resp, err := client.DispatchCheck(ctx, req) - if err != nil { - return resp, err - } - - err = adjustMetadataForDispatch(resp.Metadata) - return resp, err - }) + ctx = context.WithValue(ctx, balancer.CtxKey, requestKey) + resp, err := cr.clusterClient.DispatchCheck(ctx, req) if err != nil { return &v1.DispatchCheckResponse{Metadata: requestFailureMetadata}, err } - return resp, err + return resp, nil } -type requestMessage interface { - zerolog.LogObjectMarshaler - - GetMetadata() *v1.ResolverMeta -} - -type responseMessage interface { - proto.Message - - GetMetadata() *v1.ResponseMeta -} - -type respTuple[S responseMessage] struct { - resp S - err error -} - -type secondaryRespTuple[S responseMessage] struct { - handlerName string - resp S -} - -func dispatchRequest[Q requestMessage, S responseMessage](ctx context.Context, cr *clusterDispatcher, reqKey string, req Q, handler func(context.Context, ClusterClient) (S, error)) (S, error) { - withTimeout, cancelFn := context.WithTimeout(ctx, cr.dispatchOverallTimeout) - defer cancelFn() - - if len(cr.secondaryDispatchExprs) == 0 || len(cr.secondaryDispatch) == 0 { - return handler(withTimeout, cr.clusterClient) +func (cr *clusterDispatcher) DispatchExpand(ctx context.Context, req *v1.DispatchExpandRequest) (*v1.DispatchExpandResponse, error) { + if err := dispatch.CheckDepth(ctx, req); err != nil { + return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err } - // If no secondary dispatches are defined, just invoke directly. - expr, ok := cr.secondaryDispatchExprs[reqKey] - if !ok { - return handler(withTimeout, cr.clusterClient) + requestKey, err := cr.keyHandler.ExpandDispatchKey(ctx, req) + if err != nil { + return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err } - // Otherwise invoke in parallel with any secondary matches. - primaryResultChan := make(chan respTuple[S], 1) - secondaryResultChan := make(chan secondaryRespTuple[S], len(cr.secondaryDispatch)) - - // Run the main dispatch. - go func() { - resp, err := handler(withTimeout, cr.clusterClient) - primaryResultChan <- respTuple[S]{resp, err} - }() - - result, err := RunDispatchExpr(expr, req) + ctx = context.WithValue(ctx, balancer.CtxKey, requestKey) + resp, err := cr.clusterClient.DispatchExpand(ctx, req) if err != nil { - log.Warn().Err(err).Msg("error when trying to evaluate the dispatch expression") + return &v1.DispatchExpandResponse{Metadata: requestFailureMetadata}, err } - log.Trace().Str("secondary-dispatchers", strings.Join(result, ",")).Object("request", req).Msg("running secondary dispatchers") - - for _, secondaryDispatchName := range result { - secondary, ok := cr.secondaryDispatch[secondaryDispatchName] - if !ok { - log.Warn().Str("secondary-dispatcher-name", secondaryDispatchName).Msg("received unknown secondary dispatcher") - continue - } + return resp, nil +} - log.Trace().Str("secondary-dispatcher", secondary.Name).Object("request", req).Msg("running secondary dispatcher") - go func() { - resp, err := handler(withTimeout, secondary.Client) - if err != nil { - // For secondary dispatches, ignore any errors, as only the primary will be handled in - // that scenario. - log.Trace().Str("secondary", secondary.Name).Err(err).Msg("got ignored secondary dispatch error") - return - } - - secondaryResultChan <- secondaryRespTuple[S]{resp: resp, handlerName: secondary.Name} - }() +func (cr *clusterDispatcher) DispatchLookup(ctx context.Context, req *v1.DispatchLookupRequest) (*v1.DispatchLookupResponse, error) { + if err := dispatch.CheckDepth(ctx, req); err != nil { + return &v1.DispatchLookupResponse{Metadata: emptyMetadata}, err } - var foundError error - select { - case <-withTimeout.Done(): - return *new(S), fmt.Errorf("check dispatch has timed out") - - case r := <-primaryResultChan: - if r.err == nil { - dispatchCounter.WithLabelValues(reqKey, "(primary)").Add(1) - return r.resp, nil - } - - // Otherwise, if an error was found, log it and we'll return after *all* the secondaries have run. - // This allows an otherwise error-state to be handled by one of the secondaries. - foundError = r.err - - case r := <-secondaryResultChan: - dispatchCounter.WithLabelValues(reqKey, r.handlerName).Add(1) - return r.resp, nil + requestKey, err := cr.keyHandler.LookupResourcesDispatchKey(ctx, req) + if err != nil { + return &v1.DispatchLookupResponse{Metadata: emptyMetadata}, err } - dispatchCounter.WithLabelValues(reqKey, "(primary)").Add(1) - return *new(S), foundError -} - -type requestMessageWithCursor interface { - requestMessage - GetOptionalCursor() *v1.Cursor -} - -type responseMessageWithCursor interface { - responseMessage - GetAfterResponseCursor() *v1.Cursor -} - -type receiver[S responseMessage] interface { - Recv() (S, error) - grpc.ClientStream -} - -const ( - secondaryCursorPrefix = "$$secondary:" - primaryDispatcher = "" -) - -func publishClient[Q requestMessageWithCursor, R responseMessageWithCursor](ctx context.Context, client receiver[R], stream dispatch.Stream[R], secondaryDispatchName string) error { - for { - select { - case <-ctx.Done(): - return ctx.Err() - - default: - result, err := client.Recv() - if errors.Is(err, io.EOF) { - return nil - } else if err != nil { - return err - } - - merr := adjustMetadataForDispatch(result.GetMetadata()) - if merr != nil { - return merr - } - - if secondaryDispatchName != primaryDispatcher { - afterResponseCursor := result.GetAfterResponseCursor() - if afterResponseCursor == nil { - return spiceerrors.MustBugf("received a nil after response cursor for secondary dispatch") - } - afterResponseCursor.Sections = append([]string{secondaryCursorPrefix + secondaryDispatchName}, afterResponseCursor.Sections...) - } - - serr := stream.Publish(result) - if serr != nil { - return serr - } - } + ctx = context.WithValue(ctx, balancer.CtxKey, requestKey) + resp, err := cr.clusterClient.DispatchLookup(ctx, req) + if err != nil { + return &v1.DispatchLookupResponse{Metadata: requestFailureMetadata}, err } + + return resp, nil } -// dispatchStreamingRequest handles the dispatching of a streaming request to the primary and any -// secondary dispatchers. Unlike the non-streaming version, this will first attempt to dispatch -// from the allowed secondary dispatchers before falling back to the primary, rather than running -// them in parallel. -func dispatchStreamingRequest[Q requestMessageWithCursor, R responseMessageWithCursor]( - ctx context.Context, - cr *clusterDispatcher, - reqKey string, - req Q, - stream dispatch.Stream[R], - handler func(context.Context, ClusterClient) (receiver[R], error), +func (cr *clusterDispatcher) DispatchReachableResources( + req *v1.DispatchReachableResourcesRequest, + stream dispatch.ReachableResourcesStream, ) error { - withTimeout, cancelFn := context.WithTimeout(ctx, cr.dispatchOverallTimeout) - defer cancelFn() - - client, err := handler(withTimeout, cr.clusterClient) + requestKey, err := cr.keyHandler.ReachableResourcesDispatchKey(stream.Context(), req) if err != nil { return err } - // Check the cursor to see if the dispatch went to one of the secondary endpoints. - cursor := req.GetOptionalCursor() - cursorLockedSecondaryName := "" - if cursor != nil && len(cursor.Sections) > 0 { - if strings.HasPrefix(cursor.Sections[0], secondaryCursorPrefix) { - cursorLockedSecondaryName = strings.TrimPrefix(cursor.Sections[0], secondaryCursorPrefix) - cursor.Sections = cursor.Sections[1:] - } - } - - // If no secondary dispatches are defined, just invoke directly. - if len(cr.secondaryDispatchExprs) == 0 || len(cr.secondaryDispatch) == 0 { - return publishClient[Q](withTimeout, client, stream, primaryDispatcher) - } - - // If the cursor is locked to a known secondary, dispatch to it. - if cursorLockedSecondaryName != "" { - secondary, ok := cr.secondaryDispatch[cursorLockedSecondaryName] - if ok { - secondaryClient, err := handler(withTimeout, secondary.Client) - if err != nil { - return err - } - - log.Debug().Str("secondary-dispatcher", secondary.Name).Object("request", req).Msg("running secondary dispatcher based on cursor") - return publishClient[Q](withTimeout, secondaryClient, stream, cursorLockedSecondaryName) - } - - return fmt.Errorf("unknown secondary dispatcher in cursor: %s", cursorLockedSecondaryName) - } + ctx := context.WithValue(stream.Context(), balancer.CtxKey, requestKey) + stream = dispatch.StreamWithContext(ctx, stream) - // Otherwise, look for a matching expression for the initial secondary dispatch - // and, if present, try to dispatch to it. - expr, ok := cr.secondaryDispatchExprs[reqKey] - if !ok { - return publishClient[Q](withTimeout, client, stream, primaryDispatcher) + if err := dispatch.CheckDepth(ctx, req); err != nil { + return err } - result, err := RunDispatchExpr(expr, req) + client, err := cr.clusterClient.DispatchReachableResources(ctx, req) if err != nil { - log.Warn().Err(err).Msg("error when trying to evaluate the dispatch expression") + return err } - for _, secondaryDispatchName := range result { - secondary, ok := cr.secondaryDispatch[secondaryDispatchName] - if !ok { - log.Warn().Str("secondary-dispatcher-name", secondaryDispatchName).Msg("received unknown secondary dispatcher") - continue + for { + result, err := client.Recv() + if errors.Is(err, io.EOF) { + break } - log.Trace().Str("secondary-dispatcher", secondary.Name).Object("request", req).Msg("running secondary dispatcher") - secondaryClient, err := handler(withTimeout, secondary.Client) if err != nil { - log.Warn().Str("secondary-dispatcher", secondary.Name).Err(err).Msg("failed to create secondary dispatch client") - continue + return err } - if err := publishClient[Q](withTimeout, secondaryClient, stream, secondaryDispatchName); err != nil { - log.Warn().Str("secondary-dispatcher", secondary.Name).Err(err).Msg("failed to publish secondary dispatch response") - continue + serr := stream.Publish(result) + if serr != nil { + return serr } - - return nil - } - - // Fallback: use the primary client if no secondary matched. - return publishClient[Q](withTimeout, client, stream, primaryDispatcher) -} - -func adjustMetadataForDispatch(metadata *v1.ResponseMeta) error { - if metadata == nil { - return spiceerrors.MustBugf("received a nil metadata") - } - - // NOTE: We only add 1 to the dispatch count if it was not already handled by the downstream dispatch, - // which will only be the case in a fully cached or further undispatched call. - if metadata.DispatchCount == 0 { - metadata.DispatchCount++ } return nil } -func (cr *clusterDispatcher) DispatchExpand(ctx context.Context, req *v1.DispatchExpandRequest) (*v1.DispatchExpandResponse, error) { - if err := dispatch.CheckDepth(ctx, req); err != nil { - return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err - } - - requestKey, err := cr.keyHandler.ExpandDispatchKey(ctx, req) - if err != nil { - return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err - } - - ctx = context.WithValue(ctx, consistent.CtxKey, requestKey) - - withTimeout, cancelFn := context.WithTimeout(ctx, cr.dispatchOverallTimeout) - defer cancelFn() - - resp, err := cr.clusterClient.DispatchExpand(withTimeout, req) - if err != nil { - return &v1.DispatchExpandResponse{Metadata: requestFailureMetadata}, err - } - - err = adjustMetadataForDispatch(resp.Metadata) - return resp, err -} - -func (cr *clusterDispatcher) DispatchLookupResources2( - req *v1.DispatchLookupResources2Request, - stream dispatch.LookupResources2Stream, -) error { - requestKey, err := cr.keyHandler.LookupResources2DispatchKey(stream.Context(), req) - if err != nil { - return err - } - - ctx := context.WithValue(stream.Context(), consistent.CtxKey, requestKey) - stream = dispatch.StreamWithContext(ctx, stream) - - if err := dispatch.CheckDepth(ctx, req); err != nil { - return err - } - - return dispatchStreamingRequest(ctx, cr, "lookupresources", req, stream, - func(ctx context.Context, client ClusterClient) (receiver[*v1.DispatchLookupResources2Response], error) { - return client.DispatchLookupResources2(ctx, req) - }) -} - func (cr *clusterDispatcher) DispatchLookupSubjects( req *v1.DispatchLookupSubjectsRequest, stream dispatch.LookupSubjectsStream, @@ -427,59 +145,46 @@ func (cr *clusterDispatcher) DispatchLookupSubjects( return err } - ctx := context.WithValue(stream.Context(), consistent.CtxKey, requestKey) + ctx := context.WithValue(stream.Context(), balancer.CtxKey, requestKey) stream = dispatch.StreamWithContext(ctx, stream) if err := dispatch.CheckDepth(ctx, req); err != nil { return err } - withTimeout, cancelFn := context.WithTimeout(ctx, cr.dispatchOverallTimeout) - defer cancelFn() - - client, err := cr.clusterClient.DispatchLookupSubjects(withTimeout, req) + client, err := cr.clusterClient.DispatchLookupSubjects(ctx, req) if err != nil { return err } for { - select { - case <-withTimeout.Done(): - return withTimeout.Err() - - default: - result, err := client.Recv() - if errors.Is(err, io.EOF) { - return nil - } else if err != nil { - return err - } - - merr := adjustMetadataForDispatch(result.Metadata) - if merr != nil { - return merr - } - - serr := stream.Publish(result) - if serr != nil { - return serr - } + result, err := client.Recv() + if errors.Is(err, io.EOF) { + break + } + + if err != nil { + return err + } + + serr := stream.Publish(result) + if serr != nil { + return serr } } + + return nil } func (cr *clusterDispatcher) Close() error { return nil } -// ReadyState returns whether the underlying dispatch connection is available -func (cr *clusterDispatcher) ReadyState() dispatch.ReadyState { +// IsReady returns whether the underlying dispatch connection is available +func (cr *clusterDispatcher) IsReady() bool { state := cr.conn.GetState() log.Trace().Interface("connection-state", state).Msg("checked if cluster dispatcher is ready") - return dispatch.ReadyState{ - IsReady: state == connectivity.Ready || state == connectivity.Idle, - Message: fmt.Sprintf("found expected state when trying to connect to cluster: %v", state), - } + return state == connectivity.Ready || state == connectivity.Idle } // Always verify that we implement the interface diff --git a/internal/services/v1/relationships.go b/internal/services/v1/relationships.go index c5fb52f354..78928f5ef0 100644 --- a/internal/services/v1/relationships.go +++ b/internal/services/v1/relationships.go @@ -476,7 +476,7 @@ func (ps *permissionServer) DeleteRelationships(ctx context.Context, req *v1.Del // Delete with the specified limit. if req.OptionalLimit > 0 { deleteLimit := uint64(req.OptionalLimit) - reachedLimit, err := rwt.DeleteRelationships(ctx, req.RelationshipFilter, options.WithDeleteLimit(&deleteLimit)) + _, reachedLimit, err := rwt.DeleteRelationships(ctx, req.RelationshipFilter, options.WithDeleteLimit(&deleteLimit)) if err != nil { return err } @@ -489,7 +489,7 @@ func (ps *permissionServer) DeleteRelationships(ctx context.Context, req *v1.Del } // Otherwise, kick off an unlimited deletion. - _, err = rwt.DeleteRelationships(ctx, req.RelationshipFilter) + _, _, err = rwt.DeleteRelationships(ctx, req.RelationshipFilter) return err }, options.WithMetadata(req.OptionalTransactionMetadata)) if err != nil { diff --git a/internal/testfixtures/validating.go b/internal/testfixtures/validating.go index e15db0fc4a..adcd85715e 100644 --- a/internal/testfixtures/validating.go +++ b/internal/testfixtures/validating.go @@ -225,9 +225,9 @@ func (vrwt validatingReadWriteTransaction) WriteRelationships(ctx context.Contex return vrwt.delegate.WriteRelationships(ctx, mutations) } -func (vrwt validatingReadWriteTransaction) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, options ...options.DeleteOptionsOption) (bool, error) { +func (vrwt validatingReadWriteTransaction) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, options ...options.DeleteOptionsOption) (uint64, bool, error) { if err := filter.Validate(); err != nil { - return false, err + return 0, false, err } return vrwt.delegate.DeleteRelationships(ctx, filter, options...) diff --git a/pkg/datastore/datastore.go b/pkg/datastore/datastore.go index 3d59327c52..ca9f6ba3f5 100644 --- a/pkg/datastore/datastore.go +++ b/pkg/datastore/datastore.go @@ -504,11 +504,12 @@ type ReadWriteTransaction interface { WriteRelationships(ctx context.Context, mutations []tuple.RelationshipUpdate) error // DeleteRelationships deletes relationships that match the provided filter, with - // the optional limit. If a limit is provided and reached, the method will return - // true as the first return value. Otherwise, the boolean can be ignored. + // the optional limit. Returnds the number of deleted relationships. If a limit + // is provided and reached, the method will return true as the second return value. + // Otherwise, the boolean can be ignored. DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, options ...options.DeleteOptionsOption, - ) (bool, error) + ) (uint64, bool, error) // WriteNamespaces takes proto namespace definitions and persists them. WriteNamespaces(ctx context.Context, newConfigs ...*core.NamespaceDefinition) error diff --git a/pkg/datastore/test/relationships.go b/pkg/datastore/test/relationships.go index d80cca3797..5ed5436e13 100644 --- a/pkg/datastore/test/relationships.go +++ b/pkg/datastore/test/relationships.go @@ -241,7 +241,7 @@ func SimpleTest(t *testing.T, tester DatastoreTester) { // Delete with DeleteRelationship deletedAt, err = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { - _, err := rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ + _, _, err := rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ ResourceType: testResourceNamespace, }) require.NoError(err) @@ -415,7 +415,7 @@ func DeleteRelationshipsTest(t *testing.T, tester DatastoreTester) { require.NoError(err) deletedAt, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { - _, err := rwt.DeleteRelationships(ctx, tt.filter) + _, _, err := rwt.DeleteRelationships(ctx, tt.filter) require.NoError(err) return err }) @@ -698,7 +698,7 @@ func DeleteWithInvalidPrefixTest(t *testing.T, tester DatastoreTester) { ctx := context.Background() _, err = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { - _, err := rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ + _, _, err := rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ OptionalResourceIdPrefix: "hithere%", }) return err @@ -820,11 +820,12 @@ func DeleteWithLimitTest(t *testing.T, tester DatastoreTester) { // Delete 100 rels. var deleteLimit uint64 = 100 _, err = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { - limitReached, err := rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ + numDeleted, limitReached, err := rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ ResourceType: testResourceNamespace, }, options.WithDeleteLimit(&deleteLimit)) require.NoError(err) require.True(limitReached) + require.Equal(deleteLimit, numDeleted) return nil }) require.NoError(err) @@ -836,11 +837,12 @@ func DeleteWithLimitTest(t *testing.T, tester DatastoreTester) { // Delete the remainder. deleteLimit = 1000 _, err = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { - limitReached, err := rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ + numDeleted, limitReached, err := rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ ResourceType: testResourceNamespace, }, options.WithDeleteLimit(&deleteLimit)) require.NoError(err) require.False(limitReached) + require.Equal(uint64(900), numDeleted) return nil }) require.NoError(err) @@ -1001,7 +1003,7 @@ func DeleteRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTes // Delete the relationships and ensure matching are no longer found. _, err = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { - _, err := rwt.DeleteRelationships(ctx, tc.filter, options.WithDeleteLimit(delLimit)) + _, _, err := rwt.DeleteRelationships(ctx, tc.filter, options.WithDeleteLimit(delLimit)) return err }) require.NoError(err) @@ -1088,7 +1090,7 @@ func RecreateRelationshipsAfterDeleteWithFilter(t *testing.T, tester DatastoreTe deleteRelationships := func() error { _, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { delLimit := uint64(100) - _, err := rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ + _, _, err := rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ OptionalRelation: "owner", OptionalSubjectFilter: &v1.SubjectFilter{ SubjectType: "user", @@ -1964,7 +1966,7 @@ func BulkDeleteRelationshipsTest(t *testing.T, tester DatastoreTester) { deletedRev, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { t.Log(time.Now(), "deleting") deleteCount++ - _, err := rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ + _, _, err := rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ ResourceType: testResourceNamespace, OptionalRelation: testReaderRelation, }) diff --git a/pkg/datastore/test/watch.go b/pkg/datastore/test/watch.go index e01d74f29d..1886695931 100644 --- a/pkg/datastore/test/watch.go +++ b/pkg/datastore/test/watch.go @@ -111,7 +111,7 @@ func WatchTest(t *testing.T, tester DatastoreTester) { testUpdates = append(testUpdates, batch, []tuple.RelationshipUpdate{deleteUpdate}) _, err = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { - _, err := rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ + _, _, err := rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ ResourceType: testResourceNamespace, OptionalRelation: testReaderRelation, OptionalSubjectFilter: &v1.SubjectFilter{ diff --git a/pkg/datastore/util.go b/pkg/datastore/util.go index e72394790e..55bb8f2e37 100644 --- a/pkg/datastore/util.go +++ b/pkg/datastore/util.go @@ -28,7 +28,7 @@ func DeleteAllData(ctx context.Context, ds Datastore) error { // Delete all relationships. namespaceNames := make([]string, 0, len(nsDefs)) for _, nsDef := range nsDefs { - _, err = rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ + _, _, err = rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ ResourceType: nsDef.Definition.Name, }) if err != nil {