Skip to content

Commit

Permalink
Add LoadOrCompute method to Map and MapOf (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
puzpuzpuz authored Oct 22, 2022
1 parent e5e826b commit 1ce54e2
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 13 deletions.
43 changes: 37 additions & 6 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,25 +167,52 @@ func (m *Map) Load(key string) (value interface{}, ok bool) {

// Store sets the value for a key.
func (m *Map) Store(key string, value interface{}) {
m.doStore(key, value, false)
m.doStore(
key,
func() interface{} {
return value
},
false,
)
}

// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (m *Map) LoadOrStore(key string, value interface{}) (actual interface{}, loaded bool) {
return m.doStore(key, value, true)
return m.doStore(
key,
func() interface{} {
return value
},
true,
)
}

// LoadAndStore returns the existing value for the key if present,
// while setting the new value for the key.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false otherwise.
// It stores the new value and returns the existing one, if present.
// The loaded result is true if the existing value was loaded,
// false otherwise.
func (m *Map) LoadAndStore(key string, value interface{}) (actual interface{}, loaded bool) {
return m.doStore(key, value, false)
return m.doStore(
key,
func() interface{} {
return value
},
false,
)
}

// LoadOrCompute returns the existing value for the key if present.
// Otherwise, it computes the value using the provided function and
// returns the computed value. The loaded result is true if the value
// was loaded, false if stored.
func (m *Map) LoadOrCompute(key string, valueFn func() interface{}) (actual interface{}, loaded bool) {
return m.doStore(key, valueFn, true)
}

func (m *Map) doStore(key string, value interface{}, loadIfExists bool) (interface{}, bool) {
func (m *Map) doStore(key string, valueFn func() interface{}, loadIfExists bool) (interface{}, bool) {
// Read-only path.
if loadIfExists {
if v, ok := m.Load(key); ok {
Expand Down Expand Up @@ -241,12 +268,14 @@ func (m *Map) doStore(key string, value interface{}, loadIfExists bool) (interfa
// interface{} on each call, thus the live value pointers are
// unique. Otherwise atomic snapshot won't be correct in case
// of multiple Store calls using the same value.
value := valueFn()
nvp := unsafe.Pointer(&value)
if assertionsEnabled && vp == nvp {
panic("non-unique value pointer")
}
atomic.StorePointer(&b.values[i], nvp)
unlockBucket(&rootb.topHashMutex)
// LoadAndStore expects the old value to be returned.
return derefValue(vp), true
}
}
Expand All @@ -255,6 +284,7 @@ func (m *Map) doStore(key string, value interface{}, loadIfExists bool) (interfa
// Insertion case. First we update the value, then the key.
// This is important for atomic snapshot states.
atomic.StoreUint64(&b.topHashMutex, storeTopHash(hash, topHashes, emptyidx))
value := valueFn()
atomic.StorePointer(emptyvp, unsafe.Pointer(&value))
atomic.StorePointer(emptykp, unsafe.Pointer(&key))
unlockBucket(&rootb.topHashMutex)
Expand All @@ -271,6 +301,7 @@ func (m *Map) doStore(key string, value interface{}, loadIfExists bool) (interfa
// Create and append a new bucket.
newb := new(bucketPadded)
newb.keys[0] = unsafe.Pointer(&key)
value := valueFn()
newb.values[0] = unsafe.Pointer(&value)
newb.topHashMutex = storeTopHash(hash, topHashes, emptyidx)
atomic.StorePointer(&b.next, unsafe.Pointer(newb))
Expand Down
35 changes: 35 additions & 0 deletions map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,14 @@ func TestMapLoadOrStore_NonNilValue(t *testing.T) {
if v != newv {
t.Errorf("value does not match: %v", v)
}
newv2 := &foo{}
v, loaded = m.LoadOrStore("foo", newv2)
if !loaded {
t.Error("value was expected")
}
if v != newv {
t.Errorf("value does not match: %v", v)
}
}

func TestMapLoadAndStore_NilValue(t *testing.T) {
Expand Down Expand Up @@ -307,6 +315,33 @@ func TestMapSerialLoadOrStore(t *testing.T) {
}
}

func TestMapSerialLoadOrCompute(t *testing.T) {
const numEntries = 1000
m := NewMap()
for i := 0; i < numEntries; i++ {
v, loaded := m.LoadOrCompute(strconv.Itoa(i), func() interface{} {
return i
})
if loaded {
t.Errorf("value not computed for %d", i)
}
if vi, ok := v.(int); ok && vi != i {
t.Errorf("values do not match for %d: %v", i, v)
}
}
for i := 0; i < numEntries; i++ {
v, loaded := m.LoadOrCompute(strconv.Itoa(i), func() interface{} {
return i
})
if !loaded {
t.Errorf("value not loaded for %d", i)
}
if vi, ok := v.(int); ok && vi != i {
t.Errorf("values do not match for %d: %v", i, v)
}
}
}

func TestMapSerialStoreThenDelete(t *testing.T) {
const numEntries = 1000
m := NewMap()
Expand Down
45 changes: 38 additions & 7 deletions mapof.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,25 +114,52 @@ func (m *MapOf[K, V]) Load(key K) (value V, ok bool) {

// Store sets the value for a key.
func (m *MapOf[K, V]) Store(key K, value V) {
m.doStore(key, value, false)
m.doStore(
key,
func() V {
return value
},
false,
)
}

// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (m *MapOf[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
return m.doStore(key, value, true)
return m.doStore(
key,
func() V {
return value
},
true,
)
}

// LoadAndStore returns the existing value for the key if present,
// while setting the new value for the key.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false otherwise.
// It stores the new value and returns the existing one, if present.
// The loaded result is true if the existing value was loaded,
// false otherwise.
func (m *MapOf[K, V]) LoadAndStore(key K, value V) (actual V, loaded bool) {
return m.doStore(key, value, false)
return m.doStore(
key,
func() V {
return value
},
false,
)
}

// LoadOrCompute returns the existing value for the key if present.
// Otherwise, it computes the value using the provided function and
// returns the computed value. The loaded result is true if the value
// was loaded, false if stored.
func (m *MapOf[K, V]) LoadOrCompute(key K, valueFn func() V) (actual V, loaded bool) {
return m.doStore(key, valueFn, true)
}

func (m *MapOf[K, V]) doStore(key K, value V, loadIfExists bool) (V, bool) {
func (m *MapOf[K, V]) doStore(key K, valueFn func() V, loadIfExists bool) (V, bool) {
// Read-only path.
if loadIfExists {
if v, ok := m.Load(key); ok {
Expand Down Expand Up @@ -188,13 +215,15 @@ func (m *MapOf[K, V]) doStore(key K, value V, loadIfExists bool) (V, bool) {
// interface{} on each call, thus the live value pointers are
// unique. Otherwise atomic snapshot won't be correct in case
// of multiple Store calls using the same value.
value := valueFn()
var wv interface{} = value
nvp := unsafe.Pointer(&wv)
if assertionsEnabled && vp == nvp {
panic("non-unique value pointer")
}
atomic.StorePointer(&b.values[i], nvp)
unlockBucket(&rootb.topHashMutex)
// LoadAndStore expects the old value to be returned.
return derefTypedValue[V](vp), true
}
}
Expand All @@ -203,7 +232,8 @@ func (m *MapOf[K, V]) doStore(key K, value V, loadIfExists bool) (V, bool) {
// Insertion case. First we update the value, then the key.
// This is important for atomic snapshot states.
atomic.StoreUint64(&b.topHashMutex, storeTopHash(hash, topHashes, emptyidx))
var wv interface{} = value
value := valueFn()
var wv interface{} = valueFn()
atomic.StorePointer(emptyvp, unsafe.Pointer(&wv))
atomic.StorePointer(emptykp, unsafe.Pointer(&key))
unlockBucket(&rootb.topHashMutex)
Expand All @@ -220,6 +250,7 @@ func (m *MapOf[K, V]) doStore(key K, value V, loadIfExists bool) (V, bool) {
// Create and append a new bucket.
newb := new(bucketPadded)
newb.keys[0] = unsafe.Pointer(&key)
value := valueFn()
var wv interface{} = value
newb.values[0] = unsafe.Pointer(&wv)
newb.topHashMutex = storeTopHash(hash, topHashes, emptyidx)
Expand Down
35 changes: 35 additions & 0 deletions mapof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,14 @@ func TestMapOfLoadOrStore_NonNilValue(t *testing.T) {
if v != newv {
t.Errorf("value does not match: %v", v)
}
newv2 := &foo{}
v, loaded = m.LoadOrStore("foo", newv2)
if !loaded {
t.Error("value was expected")
}
if v != newv {
t.Errorf("value does not match: %v", v)
}
}

func TestMapOfLoadAndStore_NilValue(t *testing.T) {
Expand Down Expand Up @@ -354,6 +362,33 @@ func TestMapOfSerialLoadOrStore(t *testing.T) {
}
}

func TestMapOfSerialLoadOrCompute(t *testing.T) {
const numEntries = 1000
m := NewMapOf[int]()
for i := 0; i < numEntries; i++ {
v, loaded := m.LoadOrCompute(strconv.Itoa(i), func() int {
return i
})
if loaded {
t.Errorf("value not computed for %d", i)
}
if v != i {
t.Errorf("values do not match for %d: %v", i, v)
}
}
for i := 0; i < numEntries; i++ {
v, loaded := m.LoadOrCompute(strconv.Itoa(i), func() int {
return i
})
if !loaded {
t.Errorf("value not loaded for %d", i)
}
if v != i {
t.Errorf("values do not match for %d: %v", i, v)
}
}
}

func TestMapOfSerialStoreThenDelete(t *testing.T) {
const numEntries = 1000
m := NewMapOf[int]()
Expand Down

0 comments on commit 1ce54e2

Please sign in to comment.