diff --git a/consensus/istanbul/backend/snapshot.go b/consensus/istanbul/backend/snapshot.go index 08be3a25149e..4ab8d94256a4 100644 --- a/consensus/istanbul/backend/snapshot.go +++ b/consensus/istanbul/backend/snapshot.go @@ -22,6 +22,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus/istanbul" + "github.com/ethereum/go-ethereum/consensus/istanbul/validator" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" ) @@ -50,11 +51,11 @@ type Tally struct { type Snapshot struct { Epoch uint64 // The number of blocks after which to checkpoint and reset the pending votes - Number uint64 `json:"number"` // Block number where the snapshot was created - Hash common.Hash `json:"hash"` // Block hash where the snapshot was created - Votes []*Vote `json:"votes"` // List of votes cast in chronological order - Tally map[common.Address]Tally `json:"tally"` // Current vote tally to avoid recalculating - ValSet istanbul.ValidatorSet `json:"validators"` // Set of authorized validators at this moment + Number uint64 // Block number where the snapshot was created + Hash common.Hash // Block hash where the snapshot was created + Votes []*Vote // List of votes cast in chronological order + Tally map[common.Address]Tally // Current vote tally to avoid recalculating + ValSet istanbul.ValidatorSet // Set of authorized validators at this moment } // newSnapshot create a new snapshot with the specified startup parameters. This @@ -272,3 +273,49 @@ func (s *Snapshot) validators() []common.Address { } return validators } + +type snapshotJSON struct { + Epoch uint64 `json:"epoch"` + Number uint64 `json:"number"` + Hash common.Hash `json:"hash"` + Votes []*Vote `json:"votes"` + Tally map[common.Address]Tally `json:"tally"` + + // for validator set + Validators []common.Address `json:"validators"` + Policy istanbul.ProposerPolicy `json:"policy"` +} + +func (s *Snapshot) toJSONStruct() *snapshotJSON { + return &snapshotJSON{ + Epoch: s.Epoch, + Number: s.Number, + Hash: s.Hash, + Votes: s.Votes, + Tally: s.Tally, + Validators: s.validators(), + Policy: s.ValSet.Policy(), + } +} + +// Unmarshal from a json byte array +func (s *Snapshot) UnmarshalJSON(b []byte) error { + var j snapshotJSON + if err := json.Unmarshal(b, &j); err != nil { + return err + } + + s.Epoch = j.Epoch + s.Number = j.Number + s.Hash = j.Hash + s.Votes = j.Votes + s.Tally = j.Tally + s.ValSet = validator.NewSet(j.Validators, j.Policy) + return nil +} + +// Marshal to a json byte array +func (s *Snapshot) MarshalJSON() ([]byte, error) { + j := s.toJSONStruct() + return json.Marshal(j) +} diff --git a/consensus/istanbul/backend/snapshot_test.go b/consensus/istanbul/backend/snapshot_test.go index e00a05c97780..829f333b23e5 100644 --- a/consensus/istanbul/backend/snapshot_test.go +++ b/consensus/istanbul/backend/snapshot_test.go @@ -20,10 +20,12 @@ import ( "bytes" "crypto/ecdsa" "math/big" + "reflect" "testing" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus/istanbul" + "github.com/ethereum/go-ethereum/consensus/istanbul/validator" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" @@ -400,3 +402,54 @@ func TestVoting(t *testing.T) { } } } + +func TestSaveAndLoad(t *testing.T) { + snap := &Snapshot{ + Epoch: 5, + Number: 10, + Hash: common.HexToHash("1234567890"), + Votes: []*Vote{ + { + Validator: common.StringToAddress("1234567891"), + Block: 15, + Address: common.StringToAddress("1234567892"), + Authorize: false, + }, + }, + Tally: map[common.Address]Tally{ + common.StringToAddress("1234567893"): Tally{ + Authorize: false, + Votes: 20, + }, + }, + ValSet: validator.NewSet([]common.Address{ + common.StringToAddress("1234567894"), + common.StringToAddress("1234567895"), + }, istanbul.RoundRobin), + } + db, _ := ethdb.NewMemDatabase() + err := snap.store(db) + if err != nil { + t.Errorf("store snapshot failed: %v", err) + } + + snap1, err := loadSnapshot(snap.Epoch, db, snap.Hash) + if err != nil { + t.Errorf("load snapshot failed: %v", err) + } + if snap.Epoch != snap1.Epoch { + t.Errorf("epoch mismatch: have %v, want %v", snap1.Epoch, snap.Epoch) + } + if snap.Hash != snap1.Hash { + t.Errorf("hash mismatch: have %v, want %v", snap1.Number, snap.Number) + } + if !reflect.DeepEqual(snap.Votes, snap.Votes) { + t.Errorf("votes mismatch: have %v, want %v", snap1.Votes, snap.Votes) + } + if !reflect.DeepEqual(snap.Tally, snap.Tally) { + t.Errorf("tally mismatch: have %v, want %v", snap1.Tally, snap.Tally) + } + if !reflect.DeepEqual(snap.ValSet, snap.ValSet) { + t.Errorf("validator set mismatch: have %v, want %v", snap1.ValSet, snap.ValSet) + } +} diff --git a/consensus/istanbul/validator.go b/consensus/istanbul/validator.go index a96487cf9fbb..e0d142866e10 100644 --- a/consensus/istanbul/validator.go +++ b/consensus/istanbul/validator.go @@ -71,6 +71,8 @@ type ValidatorSet interface { Copy() ValidatorSet // Get the maximum number of faulty nodes F() int + // Get proposer policy + Policy() ProposerPolicy } // ---------------------------------------------------------------------------- diff --git a/consensus/istanbul/validator/default.go b/consensus/istanbul/validator/default.go index f10c2beedb3c..17edda55216b 100644 --- a/consensus/istanbul/validator/default.go +++ b/consensus/istanbul/validator/default.go @@ -41,16 +41,18 @@ func (val *defaultValidator) String() string { // ---------------------------------------------------------------------------- type defaultSet struct { - validators istanbul.Validators + validators istanbul.Validators + policy istanbul.ProposerPolicy + proposer istanbul.Validator validatorMu sync.RWMutex - - selector istanbul.ProposalSelector + selector istanbul.ProposalSelector } -func newDefaultSet(addrs []common.Address, selector istanbul.ProposalSelector) *defaultSet { +func newDefaultSet(addrs []common.Address, policy istanbul.ProposerPolicy) *defaultSet { valSet := &defaultSet{} + valSet.policy = policy // init validators valSet.validators = make([]istanbul.Validator, len(addrs)) for i, addr := range addrs { @@ -62,8 +64,10 @@ func newDefaultSet(addrs []common.Address, selector istanbul.ProposalSelector) * if valSet.Size() > 0 { valSet.proposer = valSet.GetByIndex(0) } - //set proposal selector - valSet.selector = selector + valSet.selector = roundRobinProposer + if policy == istanbul.Sticky { + valSet.selector = stickyProposer + } return valSet } @@ -189,7 +193,9 @@ func (valSet *defaultSet) Copy() istanbul.ValidatorSet { for _, v := range valSet.validators { addresses = append(addresses, v.Address()) } - return newDefaultSet(addresses, valSet.selector) + return NewSet(addresses, valSet.policy) } func (valSet *defaultSet) F() int { return int(math.Ceil(float64(valSet.Size())/3)) - 1 } + +func (valSet *defaultSet) Policy() istanbul.ProposerPolicy { return valSet.policy } diff --git a/consensus/istanbul/validator/default_test.go b/consensus/istanbul/validator/default_test.go index dd8d63820228..987ed12c84f1 100644 --- a/consensus/istanbul/validator/default_test.go +++ b/consensus/istanbul/validator/default_test.go @@ -78,7 +78,7 @@ func testNormalValSet(t *testing.T) { val1 := New(addr1) val2 := New(addr2) - valSet := newDefaultSet([]common.Address{addr1, addr2}, roundRobinProposer) + valSet := newDefaultSet([]common.Address{addr1, addr2}, istanbul.RoundRobin) if valSet == nil { t.Errorf("the format of validator set is invalid") t.FailNow() @@ -182,7 +182,7 @@ func testStickyProposer(t *testing.T) { val1 := New(addr1) val2 := New(addr2) - valSet := newDefaultSet([]common.Address{addr1, addr2}, stickyProposer) + valSet := newDefaultSet([]common.Address{addr1, addr2}, istanbul.Sticky) // test get proposer if val := valSet.GetProposer(); !reflect.DeepEqual(val, val1) { diff --git a/consensus/istanbul/validator/validator.go b/consensus/istanbul/validator/validator.go index c4bdad61cbcf..9a1e15c2d8fc 100644 --- a/consensus/istanbul/validator/validator.go +++ b/consensus/istanbul/validator/validator.go @@ -28,15 +28,7 @@ func New(addr common.Address) istanbul.Validator { } func NewSet(addrs []common.Address, policy istanbul.ProposerPolicy) istanbul.ValidatorSet { - switch policy { - case istanbul.RoundRobin: - return newDefaultSet(addrs, roundRobinProposer) - case istanbul.Sticky: - return newDefaultSet(addrs, stickyProposer) - } - - // use round-robin policy as default proposal policy - return newDefaultSet(addrs, roundRobinProposer) + return newDefaultSet(addrs, policy) } func ExtractValidators(extraData []byte) []common.Address {