diff --git a/cmd/end2endtest/helpers.go b/cmd/end2endtest/helpers.go index 7ac1d5509..88677afc3 100644 --- a/cmd/end2endtest/helpers.go +++ b/cmd/end2endtest/helpers.go @@ -29,6 +29,11 @@ const ( retriesSend = retries / 2 ) +type voterProof struct { + proof *apiclient.CensusProof + address string +} + func newTestElectionDescription(numChoices int) *vapi.ElectionDescription { choices := []vapi.ChoiceMetadata{} if numChoices < 2 { @@ -184,120 +189,103 @@ func (t *e2eElection) waitUntilElectionStarts(electionID types.HexBytes) (*vapi. } func (t *e2eElection) generateSIKProofs(root types.HexBytes) map[string]*apiclient.CensusProof { - type voterProof struct { - sikproof *apiclient.CensusProof - address string - } - sikProofs := make(map[string]*apiclient.CensusProof, len(t.voterAccounts)) - proofCh := make(chan *voterProof) - stopProofs := make(chan bool) - go func() { - for { - select { - case p := <-proofCh: - sikProofs[p.address] = p.sikproof - case <-stopProofs: - return - } - } - }() + var wg sync.WaitGroup + buffer := make(chan *voterProof, len(t.voterAccounts)) - addNaccounts := func(accounts []*ethereum.SignKeys, wg *sync.WaitGroup) { + addNaccounts := func(accounts []*ethereum.SignKeys) { defer wg.Done() log.Infof("generating %d sik proofs", len(accounts)) for _, acc := range accounts { voterProof := &voterProof{address: acc.Address().Hex()} - var err error voterPrivKey := acc.PrivateKey() voterApi := t.api.Clone(voterPrivKey.String()) - voterProof.sikproof, err = voterApi.GenSIKProof() + p, err := voterApi.GenSIKProof() if err != nil { log.Warn(err) } - proofCh <- voterProof + voterProof.proof = p + + buffer <- voterProof } } pcount := t.config.nvotes / t.config.parallelCount - var wg sync.WaitGroup for i := 0; i < len(t.voterAccounts); i += pcount { end := i + pcount if end > len(t.voterAccounts) { end = len(t.voterAccounts) } wg.Add(1) - go addNaccounts(t.voterAccounts[i:end], &wg) + go addNaccounts(t.voterAccounts[i:end]) } wg.Wait() - log.Debugf("%d/%d sik proofs generated successfully", len(sikProofs), len(t.voterAccounts)) - stopProofs <- true + close(buffer) - return sikProofs + proofs := make(map[string]*apiclient.CensusProof, len(t.voterAccounts)) + for p := range buffer { + proofs[p.address] = p.proof + } + + log.Debugf("%d/%d sik proofs generated successfully", len(proofs), len(t.voterAccounts)) + + return proofs } func (t *e2eElection) generateProofs(root types.HexBytes, isAnonymousVoting bool, csp *ethereum.SignKeys) map[string]*apiclient.CensusProof { - type voterProof struct { - proof *apiclient.CensusProof - address string - } - proofs := make(map[string]*apiclient.CensusProof, len(t.voterAccounts)) - proofCh := make(chan *voterProof) - stopProofs := make(chan bool) - go func() { - for { - select { - case p := <-proofCh: - proofs[p.address] = p.proof - case <-stopProofs: - return - } - } - }() + var wg sync.WaitGroup + buffer := make(chan *voterProof, len(t.voterAccounts)) - addNaccounts := func(accounts []*ethereum.SignKeys, wg *sync.WaitGroup) { + addNaccounts := func(accounts []*ethereum.SignKeys) { defer wg.Done() log.Infof("generating %d census proofs", len(accounts)) for _, acc := range accounts { voterProof := &voterProof{address: acc.Address().Hex()} - var err error if csp != nil { - voterProof.proof, err = cspGenProof(t.election.ElectionID, acc.Address().Bytes(), csp) + p, err := cspGenProof(t.election.ElectionID, acc.Address().Bytes(), csp) + if err != nil { + log.Warn(err) + } + voterProof.proof = p } else { voterPrivKey := acc.PrivateKey() voterApi := t.api.Clone(voterPrivKey.String()) - voterProof.proof, err = voterApi.CensusGenProof(root, acc.Address().Bytes()) + p, err := voterApi.CensusGenProof(root, acc.Address().Bytes()) if err != nil { log.Warn(err) } - } - if err != nil { - log.Warn(err) + voterProof.proof = p } if !isAnonymousVoting { voterProof.proof.KeyType = models.ProofArbo_ADDRESS } - proofCh <- voterProof + + buffer <- voterProof } } pcount := t.config.nvotes / t.config.parallelCount - var wg sync.WaitGroup for i := 0; i < len(t.voterAccounts); i += pcount { end := i + pcount if end > len(t.voterAccounts) { end = len(t.voterAccounts) } wg.Add(1) - go addNaccounts(t.voterAccounts[i:end], &wg) + go addNaccounts(t.voterAccounts[i:end]) } wg.Wait() + close(buffer) + + proofs := make(map[string]*apiclient.CensusProof, len(t.voterAccounts)) + for p := range buffer { + proofs[p.address] = p.proof + } + log.Debugf("%d/%d census proofs generated successfully", len(proofs), len(t.voterAccounts)) - stopProofs <- true return proofs } @@ -326,7 +314,7 @@ func (t *e2eElection) setupCensus(censusType string, nAcct int, createAccounts b // Register the accounts in the vochain if is required if censusType == vapi.CensusTypeZKWeighted && createAccounts { - errorChan := make(chan error, 1) // Create a buffered channel to prevent deadlock + errorChan := make(chan error, len(t.voterAccounts)) // Create a buffered channel to prevent deadlock wg := &sync.WaitGroup{} for i, acc := range t.voterAccounts {