From f602eda2202b9da6e0674fe38bd9cca60efa15ae Mon Sep 17 00:00:00 2001 From: Will May Date: Thu, 27 Oct 2022 14:44:29 +0100 Subject: [PATCH] Fix concurrent map writes in MappedStorageProvider Use `sync.Map` to avoid asynchronously reading & writing to a map Fixes #952 --- pkg/storage/storage_provider.go | 46 ++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 10 deletions(-) diff --git a/pkg/storage/storage_provider.go b/pkg/storage/storage_provider.go index 08425380d6..6fa038e0fc 100644 --- a/pkg/storage/storage_provider.go +++ b/pkg/storage/storage_provider.go @@ -3,40 +3,40 @@ package storage import ( "context" "fmt" + "sync" "github.com/filecoin-project/bacalhau/pkg/model" ) -// A simple storage repo that selects a storage based on the job's storage type. +// MappedStorageProvider is a simple storage repo that selects a storage based on the job's storage type. type MappedStorageProvider struct { - storages map[model.StorageSourceType]Storage - storagesInstalledCache map[model.StorageSourceType]bool + storages *genericSyncMap[model.StorageSourceType, Storage] + storagesInstalledCache *genericSyncMap[model.StorageSourceType, bool] } func NewMappedStorageProvider(storages map[model.StorageSourceType]Storage) *MappedStorageProvider { return &MappedStorageProvider{ - storages: storages, - storagesInstalledCache: map[model.StorageSourceType]bool{}, + storages: genericMapFromMap(storages), + storagesInstalledCache: genericMapFromMap(map[model.StorageSourceType]bool{}), } } func (p *MappedStorageProvider) GetStorage(ctx context.Context, storageType model.StorageSourceType) (Storage, error) { - storage, ok := p.storages[storageType] + storage, ok := p.storages.Get(storageType) if !ok { - return nil, fmt.Errorf( - "no matching storage found on this server: %s", storageType) + return nil, fmt.Errorf("no matching storage found on this server: %s", storageType) } // cache it being installed so we're not hammering it // TODO: we should evict the cache in case an installed storage gets uninstalled, or vice versa - installed, ok := p.storagesInstalledCache[storageType] + installed, ok := p.storagesInstalledCache.Get(storageType) var err error if !ok { installed, err = storage.IsInstalled(ctx) if err != nil { return nil, err } - p.storagesInstalledCache[storageType] = installed + p.storagesInstalledCache.Put(storageType, installed) } if !installed { @@ -45,3 +45,29 @@ func (p *MappedStorageProvider) GetStorage(ctx context.Context, storageType mode return storage, nil } + +func genericMapFromMap[K comparable, V any](m map[K]V) *genericSyncMap[K, V] { + ret := &genericSyncMap[K, V]{} + for k, v := range m { + ret.Put(k, v) + } + + return ret +} + +type genericSyncMap[K comparable, V any] struct { + sync.Map +} + +func (m *genericSyncMap[K, V]) Get(key K) (V, bool) { + value, ok := m.Load(key) + if !ok { + var empty V + return empty, false + } + return value.(V), true +} + +func (m *genericSyncMap[K, V]) Put(key K, value V) { + m.Store(key, value) +}