From 7ebf1656386c746b84ebf8e82f42406acc3ebde9 Mon Sep 17 00:00:00 2001 From: 0xmad <0xmad@users.noreply.github.com> Date: Wed, 3 Jul 2024 12:38:19 -0500 Subject: [PATCH] feat: proof parallelization - [x] Prepare circuit inputs and run all proofs async - [x] Minor optimization for MACI contract --- contracts/contracts/MACI.sol | 5 ++- contracts/tasks/helpers/ProofGenerator.ts | 41 ++++++++++++------- contracts/tasks/helpers/types.ts | 5 +++ contracts/tasks/runner/prove.ts | 1 + coordinator/ts/proof/proof.service.ts | 23 ++++++----- .../smart-contracts/MACI.md | 2 +- 6 files changed, 50 insertions(+), 27 deletions(-) diff --git a/contracts/contracts/MACI.sol b/contracts/contracts/MACI.sol index 758322f901..b0971cec06 100644 --- a/contracts/contracts/MACI.sol +++ b/contracts/contracts/MACI.sol @@ -24,6 +24,8 @@ contract MACI is IMACI, DomainObjs, Params, Utilities { /// if we change the state tree depth! uint8 public immutable stateTreeDepth; + uint256 public immutable signUpsLimit; + uint8 internal constant TREE_ARITY = 2; uint8 internal constant MESSAGE_TREE_ARITY = 5; @@ -112,6 +114,7 @@ contract MACI is IMACI, DomainObjs, Params, Utilities { signUpGatekeeper = _signUpGatekeeper; initialVoiceCreditProxy = _initialVoiceCreditProxy; stateTreeDepth = _stateTreeDepth; + signUpsLimit = uint256(TREE_ARITY) ** uint256(_stateTreeDepth); // Verify linked poseidon libraries if (hash2([uint256(1), uint256(1)]) == 0) revert PoseidonHashLibrariesNotLinked(); @@ -135,7 +138,7 @@ contract MACI is IMACI, DomainObjs, Params, Utilities { bytes memory _initialVoiceCreditProxyData ) public virtual { // ensure we do not have more signups than what the circuits support - if (lazyIMTData.numberOfLeaves >= uint256(TREE_ARITY) ** uint256(stateTreeDepth)) revert TooManySignups(); + if (lazyIMTData.numberOfLeaves >= signUpsLimit) revert TooManySignups(); // ensure that the public key is on the baby jubjub curve if (!CurveBabyJubJub.isOnCurve(_pubKey.x, _pubKey.y)) { diff --git a/contracts/tasks/helpers/ProofGenerator.ts b/contracts/tasks/helpers/ProofGenerator.ts index 97bdf9569b..5416a1e701 100644 --- a/contracts/tasks/helpers/ProofGenerator.ts +++ b/contracts/tasks/helpers/ProofGenerator.ts @@ -76,8 +76,13 @@ export class ProofGenerator { maciPrivateKey, coordinatorKeypair, signer, + outputDir, options: { transactionHash, stateFile, startBlock, endBlock, blocksPerBatch }, }: IPrepareStateParams): Promise { + if (!fs.existsSync(path.resolve(outputDir))) { + await fs.promises.mkdir(path.resolve(outputDir)); + } + if (stateFile) { const content = JSON.parse(fs.readFileSync(stateFile).toString()) as unknown as IJsonMaciState; const serializedPrivateKey = maciPrivateKey.serialize(); @@ -175,7 +180,6 @@ export class ProofGenerator { performance.mark("mp-proofs-start"); console.log(`Generating proofs of message processing...`); - const proofs: Proof[] = []; const { messageBatchSize } = this.poll.batchSizes; const numMessages = this.poll.messages.length; let totalMessageBatches = numMessages <= messageBatchSize ? 1 : Math.floor(numMessages / messageBatchSize); @@ -184,6 +188,8 @@ export class ProofGenerator { totalMessageBatches += 1; } + const inputs: CircuitInputs[] = []; + // while we have unprocessed messages, process them while (this.poll.hasUnprocessedMessages()) { // process messages in batches @@ -193,14 +199,19 @@ export class ProofGenerator { ) as unknown as CircuitInputs; // generate the proof for this batch - // eslint-disable-next-line no-await-in-loop - await this.generateProofs(circuitInputs, this.mp, `process_${this.poll.numBatchesProcessed - 1}.json`).then( - (data) => proofs.push(...data), - ); + inputs.push(circuitInputs); console.log(`Progress: ${this.poll.numBatchesProcessed} / ${totalMessageBatches}`); } + console.log("Wait until proof generation is finished"); + + const proofs = await Promise.all( + inputs.map((circuitInputs, index) => this.generateProofs(circuitInputs, this.mp, `process_${index}.json`)), + ).then((data) => data.reduce((acc, x) => acc.concat(x), [])); + + console.log("Proof generation is finished"); + performance.mark("mp-proofs-end"); performance.measure("Generate message processor proofs", "mp-proofs-start", "mp-proofs-end"); @@ -217,7 +228,6 @@ export class ProofGenerator { performance.mark("tally-proofs-start"); console.log(`Generating proofs of vote tallying...`); - const proofs: Proof[] = []; const { tallyBatchSize } = this.poll.batchSizes; const numStateLeaves = this.poll.stateLeaves.length; let totalTallyBatches = numStateLeaves <= tallyBatchSize ? 1 : Math.floor(numStateLeaves / tallyBatchSize); @@ -226,19 +236,26 @@ export class ProofGenerator { } let tallyCircuitInputs: CircuitInputs; + const inputs: CircuitInputs[] = []; + while (this.poll.hasUntalliedBallots()) { tallyCircuitInputs = (this.useQuadraticVoting ? this.poll.tallyVotes() : this.poll.tallyVotesNonQv()) as unknown as CircuitInputs; - // eslint-disable-next-line no-await-in-loop - await this.generateProofs(tallyCircuitInputs, this.tally, `tally_${this.poll.numBatchesTallied - 1}.json`).then( - (data) => proofs.push(...data), - ); + inputs.push(tallyCircuitInputs); console.log(`Progress: ${this.poll.numBatchesTallied} / ${totalTallyBatches}`); } + console.log("Wait until proof generation is finished"); + + const proofs = await Promise.all( + inputs.map((circuitInputs, index) => this.generateProofs(circuitInputs, this.tally, `tally_${index}.json`)), + ).then((data) => data.reduce((acc, x) => acc.concat(x), [])); + + console.log("Proof generation is finished"); + // verify the results // Compute newResultsCommitment const newResultsCommitment = genTreeCommitment( @@ -359,10 +376,6 @@ export class ProofGenerator { publicInputs: publicSignals, }); - if (!fs.existsSync(path.resolve(this.outputDir))) { - await fs.promises.mkdir(path.resolve(this.outputDir)); - } - await fs.promises.writeFile( path.resolve(this.outputDir, outputFile), JSON.stringify(proofs[proofs.length - 1], null, 4), diff --git a/contracts/tasks/helpers/types.ts b/contracts/tasks/helpers/types.ts index 1417ccf150..ce48f25677 100644 --- a/contracts/tasks/helpers/types.ts +++ b/contracts/tasks/helpers/types.ts @@ -258,6 +258,11 @@ export interface IPrepareStateParams { */ signer: Signer; + /** + * The directory to store the proofs + */ + outputDir: string; + /** * Options for state (on-chain fetching or local file) */ diff --git a/contracts/tasks/runner/prove.ts b/contracts/tasks/runner/prove.ts index 49259e3fce..9ebe997caf 100644 --- a/contracts/tasks/runner/prove.ts +++ b/contracts/tasks/runner/prove.ts @@ -109,6 +109,7 @@ task("prove", "Command to generate proof and prove the result of a poll on-chain coordinatorKeypair, pollId: poll, signer, + outputDir, options: { stateFile, transactionHash, diff --git a/coordinator/ts/proof/proof.service.ts b/coordinator/ts/proof/proof.service.ts index a8d249ec53..1200a1027f 100644 --- a/coordinator/ts/proof/proof.service.ts +++ b/coordinator/ts/proof/proof.service.ts @@ -61,8 +61,7 @@ export class ProofGeneratorService { address: maciContractAddress, }); - const signer = await this.deployment.getDeployer(); - const pollAddress = await maciContract.polls(poll); + const [signer, pollAddress] = await Promise.all([this.deployment.getDeployer(), maciContract.polls(poll)]); if (pollAddress.toLowerCase() === ZeroAddress.toLowerCase()) { this.logger.error(`Error: ${ErrorCodes.POLL_NOT_FOUND}, Poll ${poll} not found`); @@ -70,24 +69,23 @@ export class ProofGeneratorService { } const pollContract = await this.deployment.getContract({ name: EContracts.Poll, address: pollAddress }); - const [{ messageAq: messageAqAddress }, coordinatorPublicKey] = await Promise.all([ - pollContract.extContracts(), - pollContract.coordinatorPubKey(), - ]); + const [{ messageAq: messageAqAddress }, coordinatorPublicKey, isStateAqMerged, messageTreeDepth] = + await Promise.all([ + pollContract.extContracts(), + pollContract.coordinatorPubKey(), + pollContract.stateMerged(), + pollContract.treeDepths().then((depths) => Number(depths[2])), + ]); const messageAq = await this.deployment.getContract({ name: EContracts.AccQueue, address: messageAqAddress, }); - const isStateAqMerged = await pollContract.stateMerged(); - if (!isStateAqMerged) { this.logger.error(`Error: ${ErrorCodes.NOT_MERGED_STATE_TREE}, state tree is not merged`); throw new Error(ErrorCodes.NOT_MERGED_STATE_TREE); } - const messageTreeDepth = await pollContract.treeDepths().then((depths) => Number(depths[2])); - const mainRoot = await messageAq.getMainRoot(messageTreeDepth.toString()); if (mainRoot.toString() === "0") { @@ -108,6 +106,8 @@ export class ProofGeneratorService { throw new Error(ErrorCodes.PRIVATE_KEY_MISMATCH); } + const outputDir = path.resolve("./proofs"); + const maciState = await ProofGenerator.prepareState({ maciContract, pollContract, @@ -116,6 +116,7 @@ export class ProofGeneratorService { coordinatorKeypair, pollId: poll, signer, + outputDir, options: { startBlock, endBlock, @@ -137,7 +138,7 @@ export class ProofGeneratorService { tally: this.fileService.getZkeyFilePaths(process.env.COORDINATOR_TALLY_ZKEY_NAME!, useQuadraticVoting), mp: this.fileService.getZkeyFilePaths(process.env.COORDINATOR_MESSAGE_PROCESS_ZKEY_NAME!, useQuadraticVoting), rapidsnark: process.env.COORDINATOR_RAPIDSNARK_EXE, - outputDir: path.resolve("./proofs"), + outputDir, tallyOutputFile: path.resolve("./tally.json"), useQuadraticVoting, }); diff --git a/website/versioned_docs/version-v2.0_alpha/developers-references/smart-contracts/MACI.md b/website/versioned_docs/version-v2.0_alpha/developers-references/smart-contracts/MACI.md index ad6ce7aba5..da41ed461f 100644 --- a/website/versioned_docs/version-v2.0_alpha/developers-references/smart-contracts/MACI.md +++ b/website/versioned_docs/version-v2.0_alpha/developers-references/smart-contracts/MACI.md @@ -66,7 +66,7 @@ function signUp( bytes memory _initialVoiceCreditProxyData ) public virtual { // ensure we do not have more signups than what the circuits support - if (lazyIMTData.numberOfLeaves >= uint256(TREE_ARITY) ** uint256(stateTreeDepth)) revert TooManySignups(); + if (lazyIMTData.numberOfLeaves >= signUpsLimit) revert TooManySignups(); // ensure that the public key is on the baby jubjub curve if (!CurveBabyJubJub.isOnCurve(_pubKey.x, _pubKey.y)) {