Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

consortium-v2/snapshot: make FindAncientHeader more readable #424

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 53 additions & 30 deletions consensus/consortium/v2/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
blsCommon "github.com/ethereum/go-ethereum/crypto/bls/common"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/params"
"github.com/hashicorp/golang-lru/arc/v2"
)
Expand Down Expand Up @@ -243,7 +244,7 @@ func (s *Snapshot) apply(headers []*types.Header, chain consensus.ChainHeaderRea
// Change the validator set base on the size of the validators set
if number > 0 && number%s.config.EpochV2 == uint64(len(snap.validators())/2) {
// Get the most recent checkpoint header
checkpointHeader := FindAncientHeader(header, uint64(len(snap.validators())/2), chain, parents)
checkpointHeader := findAncestorHeader(header, number-uint64(len(snap.validators())/2), chain, parents)
if checkpointHeader == nil {
return nil, consensus.ErrUnknownAncestor
}
Expand Down Expand Up @@ -420,36 +421,58 @@ func (s *Snapshot) IsRecentlySigned(validator common.Address) bool {
return false
}

// FindAncientHeader finds the most recent checkpoint header
// Travel through the candidateParents to find the ancient header.
// If all headers in candidateParents have the number is larger than the header number,
// the search function will return the index, but it is not valid if we check with the
// header since the number and hash is not equals. The candidateParents is
// only available when it downloads blocks from the network.
// Otherwise, the candidateParents is nil, and it will be found by header hash and number.
func FindAncientHeader(header *types.Header, ite uint64, chain consensus.ChainHeaderReader, candidateParents []*types.Header) *types.Header {
ancient := header
for i := uint64(1); i <= ite; i++ {
parentHash := ancient.ParentHash
parentHeight := ancient.Number.Uint64() - 1
found := false
if len(candidateParents) > 0 {
index := sort.Search(len(candidateParents), func(i int) bool {
return candidateParents[i].Number.Uint64() >= parentHeight
})
if index < len(candidateParents) && candidateParents[index].Number.Uint64() == parentHeight &&
candidateParents[index].Hash() == parentHash {
ancient = candidateParents[index]
found = true
}
}
if !found {
ancient = chain.GetHeader(parentHash, parentHeight)
found = true
// findAncestorHeader traverses back to look for the requested ancestor header
// in parents list or in chaindata
//
// parents are guaranteed to be ordered and linked by the check when InsertChain
//
// There are 2 possible cases:
// Case 1: ancestor header is in parents list
// <- parents ->
// [ ancestorHeader ]
//
// Case 2: ancestor header's height is lower than parents list
// <- parents ->
// ancestorHeader ... [ ]

func findAncestorHeader(
currentHeader *types.Header,
ancestorBlockNumber uint64,
chain consensus.ChainHeaderReader,
parents []*types.Header,
) *types.Header {
// Find the first header in parents list that is higher or equal to checkpoint block
index := sort.Search(len(parents), func(i int) bool {
return parents[i].Number.Uint64() >= ancestorBlockNumber
})

// This must not happen, checkpoint header's height cannot be higher the parents list
if len(parents) != 0 && index >= len(parents) {
log.Warn(
"Checkpoint header's height is higher than parents list",
"checkpointNumber", ancestorBlockNumber,
"last parent", parents[len(parents)-1].Number,
)
return nil
}

if len(parents) != 0 && parents[index].Number.Uint64() == ancestorBlockNumber {
// Case 1: checkpoint header is in parents list
return parents[index]
} else {
// Case 2: checkpoint header's height is lower than parents list
var headerIterator *types.Header
if len(parents) != 0 {
headerIterator = parents[0]
} else {
headerIterator = currentHeader
}
if ancient == nil || !found {
return nil
for headerIterator.Number.Uint64() != ancestorBlockNumber {
headerIterator = chain.GetHeader(headerIterator.ParentHash, headerIterator.Number.Uint64()-1)
if headerIterator == nil {
return nil
}
}
return headerIterator
}
return ancient
}
106 changes: 106 additions & 0 deletions consensus/consortium/v2/snapshot_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package v2

import (
"math/big"
"testing"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/params"
)

type mockChainReader struct {
headerMapping map[common.Hash]*types.Header
}

func (chainReader *mockChainReader) Config() *params.ChainConfig { return nil }
func (chainReader *mockChainReader) CurrentHeader() *types.Header { return nil }
func (chainReader *mockChainReader) GetHeader(hash common.Hash, number uint64) *types.Header {
return chainReader.headerMapping[hash]
}
func (chainReader *mockChainReader) GetHeaderByNumber(number uint64) *types.Header { return nil }
func (chainReader *mockChainReader) GetHeaderByHash(hash common.Hash) *types.Header { return nil }
func (chainReader *mockChainReader) DB() ethdb.Database { return nil }
func (chainReader *mockChainReader) StateCache() state.Database { return nil }
func (chainReader *mockChainReader) OpEvents() []*vm.PublishEvent { return nil }

func TestFindCheckpointHeader(t *testing.T) {
// Case 1: checkpoint header is at block 5 (in parent list)
// parent list ranges from [0, 10)
parents := make([]*types.Header, 10)
for i := range parents {
parents[i] = &types.Header{Number: big.NewInt(int64(i)), Coinbase: common.BigToAddress(big.NewInt(int64(i)))}
}

currentHeader := &types.Header{Number: big.NewInt(10)}
checkpointHeader := findAncestorHeader(currentHeader, 5, nil, parents)
if checkpointHeader.Number.Cmp(big.NewInt(5)) != 0 && checkpointHeader.Coinbase != common.BigToAddress(big.NewInt(5)) {
t.Fatalf("Expect checkpoint header number: %d, got: %d", 5, checkpointHeader.Number.Int64())
}

// Case 2: checkpoint header is at 5 (lower than parent list)
// parent list ranges from [10, 20)
for i := range parents {
parents[i] = &types.Header{Number: big.NewInt(int64(i + 10)), ParentHash: common.BigToHash(big.NewInt(int64(i + 10 - 1)))}
}
mockChain := mockChainReader{
headerMapping: make(map[common.Hash]*types.Header),
}
// create mock chain 1
for i := 5; i < 10; i++ {
mockChain.headerMapping[common.BigToHash(big.NewInt(int64(100+i)))] = &types.Header{
Number: big.NewInt(int64(i)),
ParentHash: common.BigToHash(big.NewInt(int64(100 + i - 1))),
}
}

// create mock chain 2
for i := 5; i < 10; i++ {
mockChain.headerMapping[common.BigToHash(big.NewInt(int64(i)))] = &types.Header{
Number: big.NewInt(int64(i)),
ParentHash: common.BigToHash(big.NewInt(int64(i - 1))),
}
}

currentHeader = &types.Header{ParentHash: common.BigToHash(big.NewInt(19)), Number: big.NewInt(20)}
// Must traverse and get the correct header in chain 2
checkpointHeader = findAncestorHeader(currentHeader, 5, &mockChain, parents)
if checkpointHeader == nil {
t.Fatal("Failed to find checkpoint header")
}
if checkpointHeader.Number.Cmp(big.NewInt(5)) != 0 && checkpointHeader.ParentHash != common.BigToHash(big.NewInt(int64(4))) {
t.Fatalf("Expect checkpoint header number %d, parent hash: %s, got number: %d, parent hash: %s",
5, common.BigToHash(big.NewInt(int64(4))),
checkpointHeader.Number.Int64(), checkpointHeader.ParentHash,
)
}

// Case 3: find checkpoint header with nil parent list
currentHeader = &types.Header{Number: big.NewInt(10), ParentHash: common.BigToHash(big.NewInt(109))}
checkpointHeader = findAncestorHeader(currentHeader, 5, &mockChain, nil)
// Must traverse and get the correct header in chain 1
if checkpointHeader == nil {
t.Fatal("Failed to find checkpoint header")
}
if checkpointHeader.Number.Cmp(big.NewInt(5)) != 0 && checkpointHeader.ParentHash != common.BigToHash(big.NewInt(int64(104))) {
t.Fatalf("Expect checkpoint header number %d, parent hash: %s, got number: %d, parent hash: %s",
5, common.BigToHash(big.NewInt(int64(104))),
checkpointHeader.Number.Int64(), checkpointHeader.ParentHash,
)
}

// Case 4: checkpoint header is higher than parent list, this must not happen
// but the function must not crash in this case
// parent list ranges from [0, 10)
parents = make([]*types.Header, 10)
for i := range parents {
parents[i] = &types.Header{Number: big.NewInt(int64(i)), Coinbase: common.BigToAddress(big.NewInt(int64(i)))}
}
checkpointHeader = findAncestorHeader(nil, 10, nil, parents)
if checkpointHeader != nil {
t.Fatalf("Expect %v checkpoint header, got %v", nil, checkpointHeader)
}
}
Loading