Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mutex to prevent concurrent map access #428

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion mutable_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/sha256"
"fmt"
"sort"
"sync"

"github.com/pkg/errors"

Expand All @@ -29,6 +30,8 @@ type MutableTree struct {
versions map[int64]bool // The previous, saved versions of the tree.
allRootLoaded bool // Whether all roots are loaded or not(by LazyLoadVersion)
ndb *nodeDB

mtx sync.RWMutex // versions Read/write lock.
}

// NewMutableTree returns a new tree with the specified cache size and datastore.
Expand Down Expand Up @@ -59,6 +62,9 @@ func (tree *MutableTree) IsEmpty() bool {

// VersionExists returns whether or not a version exists.
func (tree *MutableTree) VersionExists(version int64) bool {
tree.mtx.RLock()
defer tree.mtx.RUnlock()

if tree.allRootLoaded {
return tree.versions[version]
}
Expand All @@ -74,6 +80,9 @@ func (tree *MutableTree) VersionExists(version int64) bool {

// AvailableVersions returns all available versions in ascending order
func (tree *MutableTree) AvailableVersions() []int {
tree.mtx.RLock()
defer tree.mtx.RUnlock()

res := make([]int, 0, len(tree.versions))
for i, v := range tree.versions {
if v {
Expand Down Expand Up @@ -318,6 +327,8 @@ func (tree *MutableTree) LazyLoadVersion(targetVersion int64) (int64, error) {
return latestVersion, ErrVersionDoesNotExist
}

tree.mtx.Lock()
defer tree.mtx.Unlock()
tree.versions[targetVersion] = true

iTree := &ImmutableTree{
Expand Down Expand Up @@ -354,6 +365,9 @@ func (tree *MutableTree) LoadVersion(targetVersion int64) (int64, error) {
firstVersion := int64(0)
latestVersion := int64(0)

tree.mtx.Lock()
defer tree.mtx.Unlock()

var latestRoot []byte
for version, r := range roots {
tree.versions[version] = true
Expand Down Expand Up @@ -411,6 +425,9 @@ func (tree *MutableTree) LoadVersionForOverwriting(targetVersion int64) (int64,

tree.ndb.resetLatestVersion(latestVersion)

tree.mtx.Lock()
defer tree.mtx.Unlock()

for v := range tree.versions {
if v > targetVersion {
delete(tree.versions, v)
Expand All @@ -429,7 +446,11 @@ func (tree *MutableTree) GetImmutable(version int64) (*ImmutableTree, error) {
}
if rootHash == nil {
return nil, ErrVersionDoesNotExist
} else if len(rootHash) == 0 {
}

tree.mtx.Lock()
defer tree.mtx.Unlock()
if len(rootHash) == 0 {
tree.versions[version] = true
return &ImmutableTree{
ndb: tree.ndb,
Expand Down Expand Up @@ -526,6 +547,8 @@ func (tree *MutableTree) SaveVersion() ([]byte, int64, error) {
return nil, version, err
}

tree.mtx.Lock()
defer tree.mtx.Unlock()
tree.version = version
tree.versions[version] = true

Expand Down Expand Up @@ -605,6 +628,8 @@ func (tree *MutableTree) DeleteVersionsRange(fromVersion, toVersion int64) error
return err
}

tree.mtx.Lock()
defer tree.mtx.Unlock()
for version := fromVersion; version < toVersion; version++ {
delete(tree.versions, version)
}
Expand All @@ -625,6 +650,8 @@ func (tree *MutableTree) DeleteVersion(version int64) error {
return err
}

tree.mtx.Lock()
defer tree.mtx.Unlock()
delete(tree.versions, version)
return nil
}
Expand Down
6 changes: 3 additions & 3 deletions nodedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ var (
)

type nodeDB struct {
mtx sync.Mutex // Read/write lock.
mtx sync.RWMutex // Read/write lock.
db dbm.DB // Persistent node storage.
batch dbm.Batch // Batched writing buffer.
opts Options // Options to customize for pruning/writing
Expand Down Expand Up @@ -68,8 +68,8 @@ func newNodeDB(db dbm.DB, cacheSize int, opts *Options) *nodeDB {
// GetNode gets a node from memory or disk. If it is an inner node, it does not
// load its children.
func (ndb *nodeDB) GetNode(hash []byte) *Node {
ndb.mtx.Lock()
defer ndb.mtx.Unlock()
ndb.mtx.RLock()
defer ndb.mtx.RUnlock()

if len(hash) == 0 {
panic("nodeDB.GetNode() requires hash")
Expand Down