backport of commit 6eeb228

biazmoreira authored Mar 6, 2025
1 parent e65b742 commit 27c0a84
Showing 11 changed files with 362 additions and 167 deletions.
3 changes: 2 additions & 1 deletion helper/storagepacker/storagepacker.go
@@ -199,7 +199,8 @@ func (s *StoragePacker) DeleteMultipleItems(ctx context.Context, logger hclog.Lo
// Look for a matching storage entries and delete them from the list.
for i := 0; i < len(bucket.Items); i++ {
if _, ok := itemsToRemove[bucket.Items[i].ID]; ok {
bucket.Items[i] = bucket.Items[len(bucket.Items)-1]
copy(bucket.Items[i:], bucket.Items[i+1:])
bucket.Items[len(bucket.Items)-1] = nil // allow GC
bucket.Items = bucket.Items[:len(bucket.Items)-1]

// Since we just moved a value to position i we need to
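The change above replaces the old "move the last element into slot i" deletion with an order-preserving removal: the tail is shifted left with copy, the now-duplicated last slot is cleared so the pointer can be garbage collected, and the slice is truncated. Below is a minimal standalone Go sketch contrasting the two strategies on a generic pointer slice; it is not the StoragePacker code itself, and the Item type here is just a stand-in.

```go
package main

import "fmt"

// Item is a stand-in for the storagepacker Item; only identity matters here.
type Item struct{ ID string }

// removeSwap is the old approach: overwrite slot i with the last element.
// O(1), but it reorders the slice.
func removeSwap(items []*Item, i int) []*Item {
	items[i] = items[len(items)-1]
	items[len(items)-1] = nil // clear the duplicated tail slot for GC
	return items[:len(items)-1]
}

// removeShift is the new approach: shift the tail left by one slot.
// O(n), but the relative order of the remaining items is preserved.
func removeShift(items []*Item, i int) []*Item {
	copy(items[i:], items[i+1:])
	items[len(items)-1] = nil // clear the duplicated tail slot for GC
	return items[:len(items)-1]
}

func main() {
	mk := func(ids ...string) []*Item {
		out := make([]*Item, 0, len(ids))
		for _, id := range ids {
			out = append(out, &Item{ID: id})
		}
		return out
	}
	a := removeSwap(mk("a", "b", "c", "d"), 1)
	b := removeShift(mk("a", "b", "c", "d"), 1)
	for _, it := range a {
		fmt.Print(it.ID, " ") // a d c — order changed
	}
	fmt.Println()
	for _, it := range b {
		fmt.Print(it.ID, " ") // a c d — order preserved
	}
	fmt.Println()
}
```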
44 changes: 44 additions & 0 deletions helper/storagepacker/storagepacker_test.go
@@ -6,6 +6,7 @@ package storagepacker
import (
"context"
"fmt"
"math/rand"
"testing"

"github.com/golang/protobuf/proto"
@@ -68,6 +69,49 @@ func BenchmarkStoragePacker(b *testing.B) {
}
}

func BenchmarkStoragePacker_DeleteMultiple(b *testing.B) {
b.StopTimer()
storagePacker, err := NewStoragePacker(&logical.InmemStorage{}, log.New(&log.LoggerOptions{Name: "storagepackertest"}), "")
if err != nil {
b.Fatal(err)
}

ctx := context.Background()

// Persist a storage entry
for i := 0; i <= 100000; i++ {
item := &Item{
ID: fmt.Sprintf("item%d", i),
}

err = storagePacker.PutItem(ctx, item)
if err != nil {
b.Fatal(err)
}

// Verify that it can be read
fetchedItem, err := storagePacker.GetItem(item.ID)
if err != nil {
b.Fatal(err)
}
if fetchedItem == nil {
b.Fatalf("failed to read the stored item")
}

if item.ID != fetchedItem.ID {
b.Fatalf("bad: item ID; expected: %q\n actual: %q\n", item.ID, fetchedItem.ID)
}
}
b.StartTimer()

for i := 0; i < b.N; i++ {
err = storagePacker.DeleteItem(ctx, fmt.Sprintf("item%d", rand.Intn(100000)))
if err != nil {
b.Fatal(err)
}
}
}

func TestStoragePacker(t *testing.T) {
storagePacker, err := NewStoragePacker(&logical.InmemStorage{}, log.New(&log.LoggerOptions{Name: "storagepackertest"}), "")
if err != nil {
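To exercise the new benchmark above on its own, something like go test -run '^$' -bench BenchmarkStoragePacker_DeleteMultiple ./helper/storagepacker should work (the -run '^$' filter skips the regular tests); note that the setup loop stores and reads back 100,001 items before the timer is restarted, so the warm-up phase is deliberately slow.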
23 changes: 23 additions & 0 deletions vault/external_tests/identity/identity_test.go
@@ -11,10 +11,12 @@ import (
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/helper/namespace"
ldaphelper "github.com/hashicorp/vault/helper/testhelpers/ldap"
"github.com/hashicorp/vault/helper/testhelpers/minimal"
"github.com/hashicorp/vault/sdk/helper/ldaputil"
"github.com/hashicorp/vault/vault"
"github.com/stretchr/testify/require"
)

@@ -642,3 +644,24 @@ func addRemoveLdapGroupMember(t *testing.T, cfg *ldaputil.ConfigEntry, userCN st
t.Fatal(err)
}
}

func findEntityFromDuplicateSet(t *testing.T, c *vault.TestClusterCore, entityIDs []string) *identity.Entity {
t.Helper()

var entity *identity.Entity

// Try fetch each ID and ensure exactly one is present
found := 0
for _, entityID := range entityIDs {
e, err := c.IdentityStore().MemDBEntityByID(entityID, true)
require.NoError(t, err)
if e != nil {
found++
entity = e
}
}
// More than one means they didn't merge as expected!
require.Equal(t, found, 1,
"node %s does not have exactly one duplicate from the set", c.NodeID)
return entity
}
4 changes: 2 additions & 2 deletions vault/identity_store.go
@@ -752,7 +752,7 @@ func (i *IdentityStore) invalidateEntityBucket(ctx context.Context, key string)
}
}

err = i.upsertEntityInTxn(ctx, txn, bucketEntity, nil, false)
_, err = i.upsertEntityInTxn(ctx, txn, bucketEntity, nil, false, false)
if err != nil {
i.logger.Error("failed to update entity in MemDB", "entity_id", bucketEntity.ID, "error", err)
return
@@ -1416,7 +1416,7 @@ func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical.
}

// Update MemDB and persist entity object
err = i.upsertEntityInTxn(ctx, txn, entity, nil, true)
_, err = i.upsertEntityInTxn(ctx, txn, entity, nil, true, false)
if err != nil {
return entity, entityCreated, err
}
13 changes: 13 additions & 0 deletions vault/identity_store_conflicts.go
@@ -28,6 +28,7 @@ type ConflictResolver interface {
ResolveEntities(ctx context.Context, existing, duplicate *identity.Entity) (bool, error)
ResolveGroups(ctx context.Context, existing, duplicate *identity.Group) (bool, error)
ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) (bool, error)
Reload(ctx context.Context)
}

// errorResolver is a ConflictResolver that logs a warning message when a
@@ -91,6 +92,10 @@ func (r *errorResolver) ResolveAliases(ctx context.Context, parent *identity.Ent
return false, errDuplicateIdentityName
}

// Reload is a no-op for the errorResolver implementation.
func (r *errorResolver) Reload(ctx context.Context) {
}

// duplicateReportingErrorResolver collects duplicate information and optionally
// logs a report on all the duplicates. We don't embed an errorResolver here
because we _don't_ want its side effect of warning on just some duplicates
Expand Down Expand Up @@ -144,6 +149,10 @@ func (r *duplicateReportingErrorResolver) ResolveAliases(ctx context.Context, pa
return false, errDuplicateIdentityName
}

func (r *duplicateReportingErrorResolver) Reload(ctx context.Context) {
r.seenEntities = make(map[string][]*identity.Entity)
}

type identityDuplicateReportEntry struct {
artifactType string
scope string
Expand Down Expand Up @@ -429,3 +438,7 @@ func (r *renameResolver) ResolveGroups(ctx context.Context, existing, duplicate
func (r *renameResolver) ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) (bool, error) {
return false, nil
}

// Reload is a no-op for the renameResolver implementation.
func (r *renameResolver) Reload(ctx context.Context) {
}
66 changes: 42 additions & 24 deletions vault/identity_store_entities.go
@@ -16,11 +16,9 @@ import (
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/helper/identity/mfa"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/storagepacker"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical"
"google.golang.org/protobuf/types/known/anypb"
)

func entityPathFields() map[string]*framework.FieldSchema {
@@ -881,7 +879,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
return errors.New("to_entity_id should not be present in from_entity_ids"), nil, nil
}

fromEntity, err := i.MemDBEntityByID(fromEntityID, false)
fromEntity, err := i.MemDBEntityByIDInTxn(txn, fromEntityID, false)
if err != nil {
return nil, err, nil
}
@@ -984,7 +982,6 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
var fromEntityGroups []*identity.Group

toEntityAccessors := make(map[string][]string)

for _, alias := range toEntity.Aliases {
if accessors, ok := toEntityAccessors[alias.MountAccessor]; !ok {
// While it is not supported to have multiple aliases with the same mount accessor in one entity
@@ -1002,7 +999,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
return errors.New("to_entity_id should not be present in from_entity_ids"), nil, nil
}

fromEntity, err := i.MemDBEntityByID(fromEntityID, true)
fromEntity, err := i.MemDBEntityByIDInTxn(txn, fromEntityID, true)
if err != nil {
return nil, err, nil
}
@@ -1025,13 +1022,20 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
}

for _, fromAlias := range fromEntity.Aliases {
// We're going to modify this alias but it's still a pointer to the one in
// MemDB that could be being read by other goroutines even though we might
// be removing from MemDB really shortly...
fromAlias, err = fromAlias.Clone()
if err != nil {
return nil, err, nil
}
// If true, we need to handle conflicts (conflict = both aliases share the same mount accessor)
if toAliasIds, ok := toEntityAccessors[fromAlias.MountAccessor]; ok {
for _, toAliasId := range toAliasIds {
// When forceMergeAliases is true (as part of the merge-during-upsert case), we make the decision
// for the user, and keep the to_entity alias, merging the from_entity
// for the user, and keep the from_entity alias
// This case's code is the same as when the user selects to keep the from_entity alias
// but is kept separate for clarity
// but is kept separate for clarity.
if forceMergeAliases {
i.logger.Info("Deleting to_entity alias during entity merge", "to_entity", toEntity.ID, "deleted_alias", toAliasId)
err := i.MemDBDeleteAliasByIDInTxn(txn, toAliasId, false)
@@ -1046,8 +1050,8 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
if err != nil {
return nil, fmt.Errorf("aborting entity merge - failed to delete orphaned alias %q during merge into entity %q: %w", fromAlias.ID, toEntity.ID, err), nil
}
// Remove the alias from the entity's list in memory too!
toEntity.DeleteAliasByID(toAliasId)
// Don't need to alter toEntity aliases since it never contained
// the alias we're deleting.

// Continue to next alias, as there's no alias to merge left in the from_entity
continue
@@ -1070,13 +1074,12 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit

fromAlias.MergedFromCanonicalIDs = append(fromAlias.MergedFromCanonicalIDs, fromEntity.ID)

err = i.MemDBUpsertAliasInTxn(txn, fromAlias, false)
if err != nil {
return nil, fmt.Errorf("failed to update alias during merge: %w", err), nil
}
// We don't insert into MemDB right now because we'll do that for all the
// aliases we want to end up with at the end to ensure they are inserted
// in the same order as when they load from storage next time.

// Add the alias to the desired entity
toEntity.Aliases = append(toEntity.Aliases, fromAlias)
toEntity.UpsertAlias(fromAlias)
}

// If told to, merge policies
@@ -1124,6 +1127,30 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
}
}

// Normalize Alias order. We do this because we persist NonLocal and Local
// aliases separately and so after next reload local aliases will all come
// after non-local ones. While it's logically equivalent, it makes reasoning
// about merges and determinism very hard if the order of things in MemDB can
// change from one unseal to the next so we are especially careful to ensure
// it's exactly the same whether we just merged or on a subsequent load.
// persistEntities will already split these up and persist them separately, so
we're kinda duplicating effort and code here but this shouldn't happen often
// so I think it's fine.
nonLocalAliases, localAliases := splitLocalAliases(toEntity)
toEntity.Aliases = append(nonLocalAliases, localAliases...)

// Don't forget to insert aliases into alias table that were part of
// `toEntity` but were not merged above (because they didn't conflict). This
// might re-insert the same aliases we just inserted above again but that's a
// no-op. TODO: maybe we could remove the memdb updates in the loop above and
// have them all be inserted here.
for _, alias := range toEntity.Aliases {
err = i.MemDBUpsertAliasInTxn(txn, alias, false)
if err != nil {
return nil, err, nil
}
}

// Update MemDB with changes to the entity we are merging to
err = i.MemDBUpsertEntityInTxn(txn, toEntity)
if err != nil {
Expand All @@ -1140,16 +1167,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit

if persist && !isPerfSecondaryOrStandby {
// Persist the entity which we are merging to
toEntityAsAny, err := anypb.New(toEntity)
if err != nil {
return nil, err, nil
}
item := &storagepacker.Item{
ID: toEntity.ID,
Message: toEntityAsAny,
}

err = i.entityPacker.PutItem(ctx, item)
err = i.persistEntity(ctx, toEntity)
if err != nil {
return nil, err, nil
}
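The alias-order normalization in this file relies on a splitLocalAliases helper that isn't shown in the diff. As a rough illustration of the partition-and-reassemble step, here is a self-contained sketch with a stub Alias type; the real helper's signature and the identity.Alias fields it inspects are assumptions, not the Vault implementation.

```go
package main

import "fmt"

// Alias is a stand-in for identity.Alias; only a "local" marker matters here.
type Alias struct {
	ID    string
	Local bool
}

// splitLocalAliases partitions aliases into non-local and local groups,
// preserving relative order within each group. Re-appending local after
// non-local yields the same ordering the entity would have after its aliases
// are persisted in separate buckets and reloaded from storage.
func splitLocalAliases(aliases []*Alias) (nonLocal, local []*Alias) {
	for _, a := range aliases {
		if a.Local {
			local = append(local, a)
		} else {
			nonLocal = append(nonLocal, a)
		}
	}
	return nonLocal, local
}

func main() {
	aliases := []*Alias{
		{ID: "ldap-alias", Local: true},
		{ID: "userpass-alias", Local: false},
		{ID: "cert-alias", Local: true},
	}
	nonLocal, local := splitLocalAliases(aliases)
	normalized := append(nonLocal, local...)
	for _, a := range normalized {
		fmt.Println(a.ID) // userpass-alias, ldap-alias, cert-alias
	}
}
```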
5 changes: 3 additions & 2 deletions vault/identity_store_injector_testonly.go
@@ -359,15 +359,16 @@ func (i *IdentityStore) createDuplicateEntityAliases() framework.OperationFunc {
flags.Count = 2
}

ids, _, err := i.CreateDuplicateEntityAliasesInStorage(ctx, flags)
ids, bucketIds, err := i.CreateDuplicateEntityAliasesInStorage(ctx, flags)
if err != nil {
i.logger.Error("error creating duplicate entities", "error", err)
return logical.ErrorResponse("error creating duplicate entities"), err
}

return &logical.Response{
Data: map[string]interface{}{
"entity_ids": ids,
"entity_ids": ids,
"bucket_keys": bucketIds,
},
}, nil
}
6 changes: 3 additions & 3 deletions vault/identity_store_oidc_test.go
@@ -889,7 +889,7 @@ func TestOIDC_SignIDToken(t *testing.T) {

txn := c.identityStore.db.Txn(true)
defer txn.Abort()
err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true)
_, err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true, false)
if err != nil {
t.Fatal(err)
}
@@ -1020,7 +1020,7 @@ func TestOIDC_SignIDToken_NilSigningKey(t *testing.T) {

txn := c.identityStore.db.Txn(true)
defer txn.Abort()
err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true)
_, err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true, false)
if err != nil {
t.Fatal(err)
}
@@ -1497,7 +1497,7 @@ func TestOIDC_Path_Introspect(t *testing.T) {

txn := c.identityStore.db.Txn(true)
defer txn.Abort()
err = c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true)
_, err = c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true, false)
if err != nil {
t.Fatal(err)
}