From d51f48ffef36fec32009acdb77ccda8fcb0395da Mon Sep 17 00:00:00 2001 From: zac-williamson Date: Tue, 25 Jul 2023 14:20:38 +0200 Subject: [PATCH 01/10] each subrelation can now choose to not scale relation term by random polynomial permutation relation arithmetic defined exclusively in relationclasses comments, naming cleanup init ecc_vm_relation_test, algebra test passes sumcheck relation test does not wip eccvm, shiftable polynomials now have 0 at first coefficient added clearer comments to eccvm relations removed redundant test files eccvm composer test passes eccvm flavor (+ upstream circuit builder) are parametrised by curve type / commitment scheme wip wip wip wip reduced eccvm relation lengths added more descriptive comments to ecc_set_relation cleaned up sumcheck relation fowarding macros added missing explicit template declarations slight cleanup of method name bloat added missing method removed unused using declaration added explicit `lookup_library.hpp` file to compute logderivative inverses typo fix formatting reversions added ECCVM concepts to flavor removed unused logderivative method each subrelation can now choose to not scale relation term by random polynomial permutation relation arithmetic defined exclusively in relationclasses comments, naming cleanup removed TypeMuncher from relation_types remove foundation --- .../honk/composer/eccvm_composer.cpp | 126 +++ .../honk/composer/eccvm_composer.hpp | 72 ++ .../honk/composer/eccvm_composer.test.cpp | 93 ++ cpp/src/barretenberg/honk/flavor/ecc_vm.hpp | 857 ++++++++++++++++++ .../honk/proof_system/eccvm_prover.cpp | 386 ++++++++ .../honk/proof_system/eccvm_prover.hpp | 85 ++ .../honk/proof_system/eccvm_verifier.cpp | 256 ++++++ .../honk/proof_system/eccvm_verifier.hpp | 47 + .../honk/proof_system/lookup_library.hpp | 64 ++ .../honk/proof_system/permutation_library.hpp | 165 ++++ .../honk/proof_system/prover_library.hpp | 11 +- .../honk/sumcheck/polynomials/univariate.hpp | 18 + .../relations/ecc_vm/ecc_vm_relation.test.cpp | 362 ++++++++ .../relations/relation_definitions_fwd.hpp | 45 + .../relations/relation_parameters.hpp | 3 + .../sumcheck/relations/relation_types.hpp | 60 +- .../eccvm/eccvm_builder_types.hpp | 36 + .../eccvm/eccvm_circuit_builder.hpp | 489 ++++++++++ .../eccvm/eccvm_circuit_builder.test.cpp | 227 +++++ .../circuit_builder/eccvm/msm_builder.hpp | 263 ++++++ .../eccvm/precomputed_tables_builder.hpp | 112 +++ .../eccvm/transcript_builder.hpp | 175 ++++ .../proof_system/flavor/flavor.hpp | 4 + .../relations/ecc_vm/ecc_lookup_relation.cpp | 89 ++ .../relations/ecc_vm/ecc_lookup_relation.hpp | 259 ++++++ .../relations/ecc_vm/ecc_msm_relation.cpp | 401 ++++++++ .../relations/ecc_vm/ecc_msm_relation.hpp | 100 ++ .../ecc_vm/ecc_point_table_relation.cpp | 176 ++++ .../ecc_vm/ecc_point_table_relation.hpp | 47 + .../relations/ecc_vm/ecc_set_relation.cpp | 393 ++++++++ .../relations/ecc_vm/ecc_set_relation.hpp | 66 ++ .../ecc_vm/ecc_transcript_relation.cpp | 255 ++++++ .../ecc_vm/ecc_transcript_relation.hpp | 92 ++ .../relations/ecc_vm/ecc_wnaf_relation.cpp | 210 +++++ .../relations/ecc_vm/ecc_wnaf_relation.hpp | 71 ++ 35 files changed, 6099 insertions(+), 16 deletions(-) create mode 100644 cpp/src/barretenberg/honk/composer/eccvm_composer.cpp create mode 100644 cpp/src/barretenberg/honk/composer/eccvm_composer.hpp create mode 100644 cpp/src/barretenberg/honk/composer/eccvm_composer.test.cpp create mode 100644 cpp/src/barretenberg/honk/flavor/ecc_vm.hpp create mode 100644 cpp/src/barretenberg/honk/proof_system/eccvm_prover.cpp create mode 100644 cpp/src/barretenberg/honk/proof_system/eccvm_prover.hpp create mode 100644 cpp/src/barretenberg/honk/proof_system/eccvm_verifier.cpp create mode 100644 cpp/src/barretenberg/honk/proof_system/eccvm_verifier.hpp create mode 100644 cpp/src/barretenberg/honk/proof_system/lookup_library.hpp create mode 100644 cpp/src/barretenberg/honk/proof_system/permutation_library.hpp create mode 100644 cpp/src/barretenberg/honk/sumcheck/relations/ecc_vm/ecc_vm_relation.test.cpp create mode 100644 cpp/src/barretenberg/honk/sumcheck/relations/relation_definitions_fwd.hpp create mode 100644 cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_builder_types.hpp create mode 100644 cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.hpp create mode 100644 cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.test.cpp create mode 100644 cpp/src/barretenberg/proof_system/circuit_builder/eccvm/msm_builder.hpp create mode 100644 cpp/src/barretenberg/proof_system/circuit_builder/eccvm/precomputed_tables_builder.hpp create mode 100644 cpp/src/barretenberg/proof_system/circuit_builder/eccvm/transcript_builder.hpp create mode 100644 cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_lookup_relation.cpp create mode 100644 cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_lookup_relation.hpp create mode 100644 cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_msm_relation.cpp create mode 100644 cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_msm_relation.hpp create mode 100644 cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_point_table_relation.cpp create mode 100644 cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_point_table_relation.hpp create mode 100644 cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_set_relation.cpp create mode 100644 cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_set_relation.hpp create mode 100644 cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_transcript_relation.cpp create mode 100644 cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_transcript_relation.hpp create mode 100644 cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_wnaf_relation.cpp create mode 100644 cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_wnaf_relation.hpp diff --git a/cpp/src/barretenberg/honk/composer/eccvm_composer.cpp b/cpp/src/barretenberg/honk/composer/eccvm_composer.cpp new file mode 100644 index 0000000000..2975cac32d --- /dev/null +++ b/cpp/src/barretenberg/honk/composer/eccvm_composer.cpp @@ -0,0 +1,126 @@ +#include "./eccvm_composer.hpp" +#include "barretenberg/honk/proof_system/ultra_prover.hpp" +#include "barretenberg/proof_system/composer/composer_lib.hpp" +#include "barretenberg/proof_system/composer/permutation_lib.hpp" + +namespace proof_system::honk { + +/** + * @brief Compute witness polynomials + * + */ +template +void ECCVMComposerHelper_::compute_witness(CircuitConstructor& circuit_constructor) +{ + if (computed_witness) { + return; + } + + auto polynomials = circuit_constructor.compute_full_polynomials(); + + auto key_wires = proving_key->get_wires(); + auto poly_wires = polynomials.get_wires(); + + for (size_t i = 0; i < key_wires.size(); ++i) { + std::copy(poly_wires[i].begin(), poly_wires[i].end(), key_wires[i].begin()); + } + + computed_witness = true; +} + +template +ECCVMProver_ ECCVMComposerHelper_::create_prover(CircuitConstructor& circuit_constructor) +{ + compute_proving_key(circuit_constructor); + compute_witness(circuit_constructor); + compute_commitment_key(proving_key->circuit_size); + + ECCVMProver_ output_state(proving_key, commitment_key); + + return output_state; +} + +/** + * Create verifier: compute verification key, + * initialize verifier with it and an initial manifest and initialize commitment_scheme. + * + * @return The verifier. + * */ +template +ECCVMVerifier_ ECCVMComposerHelper_::create_verifier(CircuitConstructor& circuit_constructor) +{ + auto verification_key = compute_verification_key(circuit_constructor); + + ECCVMVerifier_ output_state(verification_key); + + auto pcs_verification_key = std::make_unique(verification_key->circuit_size, crs_factory_); + + output_state.pcs_verification_key = std::move(pcs_verification_key); + + return output_state; +} + +template +std::shared_ptr ECCVMComposerHelper_::compute_proving_key( + CircuitConstructor& circuit_constructor) +{ + if (proving_key) { + return proving_key; + } + + // Initialize proving_key + // TODO(#392)(Kesha): replace composer types. + { + // TODO: get num gates in a more efficient way + const auto rows = circuit_constructor.compute_full_polynomials(); + const size_t subgroup_size = rows.lagrange_first.size(); + // Differentiate between Honk and Plonk here since Plonk pkey requires crs whereas Honk pkey does not + proving_key = std::make_shared(subgroup_size, 0); + } + + // construct_selector_polynomials(circuit_constructor, proving_key.get()); + + // TODO(@zac-williamson): We don't enforce nonzero selectors atm. Will create problems in recursive setting. Needs + // fix enforce_nonzero_polynomial_selectors(circuit_constructor, proving_key.get()); + + compute_first_and_last_lagrange_polynomials(proving_key.get()); + { + const size_t n = proving_key->circuit_size; + typename Flavor::Polynomial lagrange_polynomial_second(n); + lagrange_polynomial_second[1] = 1; + proving_key->lagrange_second = lagrange_polynomial_second; + } + + proving_key->contains_recursive_proof = false; + + return proving_key; +} + +/** + * Compute verification key consisting of selector precommitments. + * + * @return Pointer to created circuit verification key. + * */ +template +std::shared_ptr ECCVMComposerHelper_::compute_verification_key( + CircuitConstructor& circuit_constructor) +{ + if (verification_key) { + return verification_key; + } + + if (!proving_key) { + compute_proving_key(circuit_constructor); + } + + verification_key = std::make_shared( + proving_key->circuit_size, proving_key->num_public_inputs); + + verification_key->lagrange_first = commitment_key->commit(proving_key->lagrange_first); + verification_key->lagrange_second = commitment_key->commit(proving_key->lagrange_second); + verification_key->lagrange_last = commitment_key->commit(proving_key->lagrange_last); + return verification_key; +} +template class ECCVMComposerHelper_; + +} // namespace proof_system::honk diff --git a/cpp/src/barretenberg/honk/composer/eccvm_composer.hpp b/cpp/src/barretenberg/honk/composer/eccvm_composer.hpp new file mode 100644 index 0000000000..8c9fe7425e --- /dev/null +++ b/cpp/src/barretenberg/honk/composer/eccvm_composer.hpp @@ -0,0 +1,72 @@ +#pragma once + +#include "barretenberg/honk/proof_system/eccvm_prover.hpp" +#include "barretenberg/honk/proof_system/eccvm_verifier.hpp" +#include "barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.hpp" +#include "barretenberg/proof_system/composer/composer_lib.hpp" +#include "barretenberg/srs/factories/file_crs_factory.hpp" + +namespace proof_system::honk { +template class ECCVMComposerHelper_ { + public: + using CircuitConstructor = ECCVMCircuitConstructor; + using ProvingKey = typename Flavor::ProvingKey; + using VerificationKey = typename Flavor::VerificationKey; + using PCSParams = typename Flavor::PCSParams; + using PCS = typename Flavor::PCS; + using PCSCommitmentKey = typename PCSParams::CommitmentKey; + using PCSVerificationKey = typename PCSParams::VerificationKey; + + static constexpr std::string_view NAME_STRING = "ECCVM"; + static constexpr size_t NUM_RESERVED_GATES = 0; // equal to the number of multilinear evaluations leaked + static constexpr size_t NUM_WIRES = CircuitConstructor::NUM_WIRES; + std::shared_ptr proving_key; + std::shared_ptr verification_key; + + // The crs_factory holds the path to the srs and exposes methods to extract the srs elements + std::shared_ptr crs_factory_; + + // The commitment key is passed to the prover but also used herein to compute the verfication key commitments + std::shared_ptr commitment_key; + + std::vector recursive_proof_public_input_indices; + bool contains_recursive_proof = false; + bool computed_witness = false; + + ECCVMComposerHelper_() + : crs_factory_(barretenberg::srs::get_crs_factory()){}; + + explicit ECCVMComposerHelper_(std::shared_ptr crs_factory) + : crs_factory_(std::move(crs_factory)) + {} + + ECCVMComposerHelper_(std::shared_ptr p_key, std::shared_ptr v_key) + : proving_key(std::move(p_key)) + , verification_key(std::move(v_key)) + {} + + ECCVMComposerHelper_(ECCVMComposerHelper_&& other) noexcept = default; + ECCVMComposerHelper_(ECCVMComposerHelper_ const& other) noexcept = default; + ECCVMComposerHelper_& operator=(ECCVMComposerHelper_&& other) noexcept = default; + ECCVMComposerHelper_& operator=(ECCVMComposerHelper_ const& other) noexcept = default; + ~ECCVMComposerHelper_() = default; + + std::shared_ptr compute_proving_key(CircuitConstructor& circuit_constructor); + std::shared_ptr compute_verification_key(CircuitConstructor& circuit_constructor); + + void compute_witness(CircuitConstructor& circuit_constructor); + + ECCVMProver_ create_prover(CircuitConstructor& circuit_constructor); + ECCVMVerifier_ create_verifier(CircuitConstructor& circuit_constructor); + + void add_table_column_selector_poly_to_proving_key(polynomial& small, const std::string& tag); + + void compute_commitment_key(size_t circuit_size) + { + commitment_key = std::make_shared(circuit_size, crs_factory_); + }; +}; +extern template class ECCVMComposerHelper_; +// TODO(#532): this pattern is weird; is this not instantiating the templates? +using ECCVMComposerHelper = ECCVMComposerHelper_; +} // namespace proof_system::honk diff --git a/cpp/src/barretenberg/honk/composer/eccvm_composer.test.cpp b/cpp/src/barretenberg/honk/composer/eccvm_composer.test.cpp new file mode 100644 index 0000000000..e117b16f3d --- /dev/null +++ b/cpp/src/barretenberg/honk/composer/eccvm_composer.test.cpp @@ -0,0 +1,93 @@ +#include +#include +#include +#include + +#include "barretenberg/honk/composer/eccvm_composer.hpp" +#include "barretenberg/honk/proof_system/prover.hpp" +#include "barretenberg/honk/sumcheck/relations/permutation_relation.hpp" +#include "barretenberg/honk/sumcheck/relations/relation_parameters.hpp" +#include "barretenberg/honk/sumcheck/sumcheck_round.hpp" +#include "barretenberg/honk/utils/grand_product_delta.hpp" +#include "barretenberg/numeric/uint256/uint256.hpp" +#include "barretenberg/polynomials/polynomial.hpp" +#include "barretenberg/proof_system/circuit_constructors/eccvm/eccvm_circuit_builder.hpp" + +using namespace proof_system::honk; + +namespace test_standard_honk_composer { + +class ECCVMComposerTests : public ::testing::Test { + protected: + static void SetUpTestSuite() { barretenberg::srs::init_crs_factory("../srs_db/ignition"); } +}; +namespace { +auto& engine = numeric::random::get_debug_engine(); +} +proof_system::ECCVMCircuitConstructor generate_trace(numeric::random::Engine* engine = nullptr) +{ + proof_system::ECCVMCircuitConstructor result; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::g1::element b = grumpkin::get_generator(1); + grumpkin::g1::element c = grumpkin::get_generator(2); + grumpkin::fr x = grumpkin::fr::random_element(engine); + grumpkin::fr y = grumpkin::fr::random_element(engine); + + grumpkin::g1::element expected_1 = (a * x) + a + a + (b * y) + (b * x) + (b * x); + grumpkin::g1::element expected_2 = (a * x) + c + (b * x); + + result.add_accumulate(a); + result.mul_accumulate(a, x); + result.mul_accumulate(b, x); + result.mul_accumulate(b, y); + result.add_accumulate(a); + result.mul_accumulate(b, x); + result.eq(expected_1); + result.add_accumulate(c); + result.mul_accumulate(a, x); + result.mul_accumulate(b, x); + result.eq(expected_2); + result.mul_accumulate(a, x); + result.mul_accumulate(b, x); + result.mul_accumulate(c, x); + + return result; +} + +TEST_F(ECCVMComposerTests, BaseCase) +{ + auto circuit_constructor = generate_trace(&engine); + + auto composer = ECCVMComposerHelper(); + auto prover = composer.create_prover(circuit_constructor); + + // / size_t pidx = 0; + // for (auto& p : prover.prover_polynomials) { + // size_t count = 0; + // for (auto& x : p) { + // std::cout << "poly[" << pidx << "][" << count << "] = " << x << std::endl; + // count++; + // } + // pidx++; + // } + auto proof = prover.construct_proof(); + auto verifier = composer.create_verifier(circuit_constructor); + bool verified = verifier.verify_proof(proof); + ASSERT_TRUE(verified); +} + +TEST_F(ECCVMComposerTests, EqFails) +{ + auto circuit_constructor = generate_trace(&engine); + // create an eq opcode that is not satisfied + circuit_constructor.eq(grumpkin::g1::affine_one); + auto composer = ECCVMComposerHelper(); + auto prover = composer.create_prover(circuit_constructor); + + auto proof = prover.construct_proof(); + auto verifier = composer.create_verifier(circuit_constructor); + bool verified = verifier.verify_proof(proof); + ASSERT_FALSE(verified); +} +} // namespace test_standard_honk_composer diff --git a/cpp/src/barretenberg/honk/flavor/ecc_vm.hpp b/cpp/src/barretenberg/honk/flavor/ecc_vm.hpp new file mode 100644 index 0000000000..3428506447 --- /dev/null +++ b/cpp/src/barretenberg/honk/flavor/ecc_vm.hpp @@ -0,0 +1,857 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include "barretenberg/honk/pcs/commitment_key.hpp" +#include "barretenberg/honk/sumcheck/polynomials/univariate.hpp" +#include "barretenberg/proof_system/relations/ecc_vm/ecc_wnaf_relation.hpp" +#include "barretenberg/proof_system/flavor/flavor.hpp" +#include "barretenberg/proof_system/relations/ecc_vm/ecc_lookup_relation.hpp" +#include "barretenberg/proof_system/relations/ecc_vm/ecc_msm_relation.hpp" +#include "barretenberg/proof_system/relations/ecc_vm/ecc_point_table_relation.hpp" +#include "barretenberg/proof_system/relations/ecc_vm/ecc_set_relation.hpp" +#include "barretenberg/proof_system/relations/ecc_vm/ecc_transcript_relation.hpp" +#include "../sumcheck/relations/relation_types.hpp" +#include "../sumcheck/relations/relation_definitions_fwd.hpp" +#include "barretenberg/honk/pcs/kzg/kzg.hpp" +#include "barretenberg/honk/pcs/ipa/ipa.hpp" + +// NOLINTBEGIN(cppcoreguidelines-avoid-const-or-ref-data-members) + +namespace proof_system::honk { +namespace flavor { + +template typename PCS_T> +class ECCVMBase { + public: + using CycleGroup = CycleGroup_T; + // forward template params into the ECCVMBase namespace + using G1 = G1_T; + using PCSParams = PCSParams_T; + using PCS = PCS_T; + + using FF = typename G1::subgroup_field; + using Polynomial = barretenberg::Polynomial; + using PolynomialHandle = std::span; + using GroupElement = G1::element; + using Commitment = G1::affine_element; + using CommitmentHandle = G1::affine_element; + + static constexpr size_t NUM_WIRES = 74; + + // The number of multivariate polynomials on which a sumcheck prover sumcheck operates (including shifts). We often + // need containers of this size to hold related data, so we choose a name more agnostic than `NUM_POLYNOMIALS`. + // Note: this number does not include the individual sorted list polynomials. + static constexpr size_t NUM_ALL_ENTITIES = 105; + // The number of polynomials precomputed to describe a circuit and to aid a prover in constructing a satisfying + // assignment of witnesses. We again choose a neutral name. + static constexpr size_t NUM_PRECOMPUTED_ENTITIES = 3; + // The total number of witness entities not including shifts. + static constexpr size_t NUM_WITNESS_ENTITIES = 76; + + using GrandProductRelations = std::tuple>; + // define the tuple of Relations that comprise the Sumcheck relation + using Relations = std::tuple, + sumcheck::ECCVMPointTableRelation, + sumcheck::ECCVMWnafRelation, + sumcheck::ECCVMMSMRelation, + sumcheck::ECCVMSetRelation, + sumcheck::ECCVMLookupRelation>; + + using LookupRelation = sumcheck::ECCVMLookupRelation; + static constexpr size_t MAX_RELATION_LENGTH = get_max_relation_length(); + + // MAX_RANDOM_RELATION_LENGTH = algebraic degree of sumcheck relation *after* multiplying by the `pow_zeta` random + // polynomial e.g. For \sum(x) [A(x) * B(x) + C(x)] * PowZeta(X), relation length = 2 and random relation length = 3 + static constexpr size_t MAX_RANDOM_RELATION_LENGTH = MAX_RELATION_LENGTH + 1; + static constexpr size_t NUM_RELATIONS = std::tuple_size::value; + + // Instantiate the BarycentricData needed to extend each Relation Univariate + // static_assert(instantiate_barycentric_utils()); + + // define the containers for storing the contributions from each relation in Sumcheck + using RelationUnivariates = decltype(create_relation_univariates_container()); + using RelationValues = decltype(create_relation_values_container()); + + private: + // class Counter { + // constexpr size_t foo() + // { + // return Thing<>; + // } + // }; + /** + * @brief A base class labelling precomputed entities and (ordered) subsets of interest. + * @details Used to build the proving key and verification key. + */ + template + class PrecomputedEntities : public PrecomputedEntities_ { + public: + DataType& lagrange_first = std::get<0>(this->_data); + DataType& lagrange_second = std::get<1>(this->_data); + DataType& lagrange_last = std::get<2>(this->_data); + + std::vector get_selectors() override { return { lagrange_first, lagrange_second, lagrange_last }; }; + std::vector get_sigma_polynomials() override { return {}; }; + std::vector get_id_polynomials() override { return {}; }; + std::vector get_table_polynomials() { return {}; }; + }; + + /** + * @brief Container for all witness polynomials used/constructed by the prover. + * @details Shifts are not included here since they do not occupy their own memory. + */ + template + class WitnessEntities : public WitnessEntities_ { + public: + // clang-format off + DataType& q_transcript_add = std::get<0>(this->_data); + DataType& q_transcript_mul = std::get<1>(this->_data); + DataType& q_transcript_eq = std::get<2>(this->_data); + DataType& q_transcript_accumulate = std::get<3>(this->_data); + DataType& q_transcript_msm_transition = std::get<4>(this->_data); + DataType& transcript_pc = std::get<5>(this->_data); + DataType& transcript_msm_count = std::get<6>(this->_data); + DataType& transcript_x = std::get<7>(this->_data); + DataType& transcript_y = std::get<8>(this->_data); + DataType& transcript_z1 = std::get<9>(this->_data); + DataType& transcript_z2 = std::get<10>(this->_data); + DataType& transcript_z1zero = std::get<11>(this->_data); + DataType& transcript_z2zero = std::get<12>(this->_data); + DataType& transcript_op = std::get<13>(this->_data); + DataType& transcript_accumulator_x = std::get<14>(this->_data); + DataType& transcript_accumulator_y = std::get<15>(this->_data); + DataType& transcript_msm_x = std::get<16>(this->_data); + DataType& transcript_msm_y = std::get<17>(this->_data); + DataType& table_pc = std::get<18>(this->_data); + DataType& table_point_transition = std::get<19>(this->_data); + DataType& table_round = std::get<20>(this->_data); + DataType& table_scalar_sum = std::get<21>(this->_data); + DataType& table_s1 = std::get<22>(this->_data); + DataType& table_s2 = std::get<23>(this->_data); + DataType& table_s3 = std::get<24>(this->_data); + DataType& table_s4 = std::get<25>(this->_data); + DataType& table_s5 = std::get<26>(this->_data); + DataType& table_s6 = std::get<27>(this->_data); + DataType& table_s7 = std::get<28>(this->_data); + DataType& table_s8 = std::get<29>(this->_data); + DataType& table_skew = std::get<30>(this->_data); + DataType& table_dx = std::get<31>(this->_data); + DataType& table_dy = std::get<32>(this->_data); + DataType& table_tx = std::get<33>(this->_data); + DataType& table_ty = std::get<34>(this->_data); + DataType& q_msm_transition = std::get<35>(this->_data); + DataType& msm_q_add = std::get<36>(this->_data); + DataType& msm_q_double = std::get<37>(this->_data); + DataType& msm_q_skew = std::get<38>(this->_data); + DataType& msm_accumulator_x = std::get<39>(this->_data); + DataType& msm_accumulator_y = std::get<40>(this->_data); + DataType& msm_pc = std::get<41>(this->_data); + DataType& msm_size_of_msm = std::get<42>(this->_data); + DataType& msm_count = std::get<43>(this->_data); + DataType& msm_round = std::get<44>(this->_data); + DataType& msm_q_add1 = std::get<45>(this->_data); + DataType& msm_q_add2 = std::get<46>(this->_data); + DataType& msm_q_add3 = std::get<47>(this->_data); + DataType& msm_q_add4 = std::get<48>(this->_data); + DataType& msm_x1 = std::get<49>(this->_data); + DataType& msm_y1 = std::get<50>(this->_data); + DataType& msm_x2 = std::get<51>(this->_data); + DataType& msm_y2 = std::get<52>(this->_data); + DataType& msm_x3 = std::get<53>(this->_data); + DataType& msm_y3 = std::get<54>(this->_data); + DataType& msm_x4 = std::get<55>(this->_data); + DataType& msm_y4 = std::get<56>(this->_data); + DataType& msm_collision_x1 = std::get<57>(this->_data); + DataType& msm_collision_x2 = std::get<58>(this->_data); + DataType& msm_collision_x3 = std::get<59>(this->_data); + DataType& msm_collision_x4 = std::get<60>(this->_data); + DataType& msm_lambda1 = std::get<61>(this->_data); + DataType& msm_lambda2 = std::get<62>(this->_data); + DataType& msm_lambda3 = std::get<63>(this->_data); + DataType& msm_lambda4 = std::get<64>(this->_data); + DataType& msm_slice1 = std::get<65>(this->_data); + DataType& msm_slice2 = std::get<66>(this->_data); + DataType& msm_slice3 = std::get<67>(this->_data); + DataType& msm_slice4 = std::get<68>(this->_data); + DataType& transcript_accumulator_empty = std::get<69>(this->_data); + DataType& transcript_q_reset_accumulator = std::get<70>(this->_data); + DataType& q_wnaf = std::get<71>(this->_data); + DataType& lookup_read_counts_0 = std::get<72>(this->_data); + DataType& lookup_read_counts_1 = std::get<73>(this->_data); + DataType& z_perm = std::get<74>(this->_data); + DataType& lookup_inverses = std::get<75>(this->_data); + + // clang-format on + std::vector get_wires() override + { + return { + q_transcript_add, + q_transcript_mul, + q_transcript_eq, + q_transcript_accumulate, + q_transcript_msm_transition, + transcript_pc, + transcript_msm_count, + transcript_x, + transcript_y, + transcript_z1, + transcript_z2, + transcript_z1zero, + transcript_z2zero, + transcript_op, + transcript_accumulator_x, + transcript_accumulator_y, + transcript_msm_x, + transcript_msm_y, + table_pc, + table_point_transition, + table_round, + table_scalar_sum, + table_s1, + table_s2, + table_s3, + table_s4, + table_s5, + table_s6, + table_s7, + table_s8, + table_skew, + table_dx, + table_dy, + table_tx, + table_ty, + q_msm_transition, + msm_q_add, + msm_q_double, + msm_q_skew, + msm_accumulator_x, + msm_accumulator_y, + msm_pc, + msm_size_of_msm, + msm_count, + msm_round, + msm_q_add1, + msm_q_add2, + msm_q_add3, + msm_q_add4, + msm_x1, + msm_y1, + msm_x2, + msm_y2, + msm_x3, + msm_y3, + msm_x4, + msm_y4, + msm_collision_x1, + msm_collision_x2, + msm_collision_x3, + msm_collision_x4, + msm_lambda1, + msm_lambda2, + msm_lambda3, + msm_lambda4, + msm_slice1, + msm_slice2, + msm_slice3, + msm_slice4, + transcript_accumulator_empty, + transcript_q_reset_accumulator, + q_wnaf, + lookup_read_counts_0, + lookup_read_counts_1, + }; + }; + // The sorted concatenations of table and witness data needed for plookup. + std::vector get_sorted_polynomials() { return {}; }; + }; + + /** + * @brief A base class labelling all entities (for instance, all of the polynomials used by the prover during + * sumcheck) in this Honk variant along with particular subsets of interest + * @details Used to build containers for: the prover's polynomial during sumcheck; the sumcheck's folded + * polynomials; the univariates consturcted during during sumcheck; the evaluations produced by sumcheck. + * + * Symbolically we have: AllEntities = PrecomputedEntities + WitnessEntities + "ShiftedEntities". It could be + * implemented as such, but we have this now. + */ + // SUEHRGFPIEAUHFPAWEIUFHEAWP9UFH NEED TO MAKE SURE POINTS ARE NOT POINTS AT INFINITY + // I3EUBFPEWUBEWOPFUHEWPIFUHEPWFUHEQWOIFUHEOLRFHEQPFUHEQPFUH I.E. ALL ARE NONZERO EPFHUEPFHUEPGRFGHBEWOFIEHUPOFUHRNF + template + class AllEntities : public AllEntities_ { + public: + // clang-format off + DataType& lagrange_first = std::get<0>(this->_data); + DataType& lagrange_second = std::get<1>(this->_data); + DataType& lagrange_last = std::get<2>(this->_data); + DataType& q_transcript_add = std::get<3 + 0>(this->_data); + DataType& q_transcript_mul = std::get<3 + 1>(this->_data); + DataType& q_transcript_eq = std::get<3 + 2>(this->_data); + DataType& q_transcript_accumulate = std::get<3 + 3>(this->_data); + DataType& q_transcript_msm_transition = std::get<3 + 4>(this->_data); + DataType& transcript_pc = std::get<3 + 5>(this->_data); + DataType& transcript_msm_count = std::get<3 + 6>(this->_data); + DataType& transcript_x = std::get<3 + 7>(this->_data); + DataType& transcript_y = std::get<3 + 8>(this->_data); + DataType& transcript_z1 = std::get<3 + 9>(this->_data); + DataType& transcript_z2 = std::get<3 + 10>(this->_data); + DataType& transcript_z1zero = std::get<3 + 11>(this->_data); + DataType& transcript_z2zero = std::get<3 + 12>(this->_data); + DataType& transcript_op = std::get<3 + 13>(this->_data); + DataType& transcript_accumulator_x = std::get<3 + 14>(this->_data); + DataType& transcript_accumulator_y = std::get<3 + 15>(this->_data); + DataType& transcript_msm_x = std::get<3 + 16>(this->_data); + DataType& transcript_msm_y = std::get<3 + 17>(this->_data); + DataType& table_pc = std::get<3 + 18>(this->_data); + DataType& table_point_transition = std::get<3 + 19>(this->_data); + DataType& table_round = std::get<3 + 20>(this->_data); + DataType& table_scalar_sum = std::get<3 + 21>(this->_data); + DataType& table_s1 = std::get<3 + 22>(this->_data); + DataType& table_s2 = std::get<3 + 23>(this->_data); + DataType& table_s3 = std::get<3 + 24>(this->_data); + DataType& table_s4 = std::get<3 + 25>(this->_data); + DataType& table_s5 = std::get<3 + 26>(this->_data); + DataType& table_s6 = std::get<3 + 27>(this->_data); + DataType& table_s7 = std::get<3 + 28>(this->_data); + DataType& table_s8 = std::get<3 + 29>(this->_data); + DataType& table_skew = std::get<3 + 30>(this->_data); + DataType& table_dx = std::get<3 + 31>(this->_data); + DataType& table_dy = std::get<3 + 32>(this->_data); + DataType& table_tx = std::get<3 + 33>(this->_data); + DataType& table_ty = std::get<3 + 34>(this->_data); + DataType& q_msm_transition = std::get<3 + 35>(this->_data); + DataType& msm_q_add = std::get<3 + 36>(this->_data); + DataType& msm_q_double = std::get<3 + 37>(this->_data); + DataType& msm_q_skew = std::get<3 + 38>(this->_data); + DataType& msm_accumulator_x = std::get<3 + 39>(this->_data); + DataType& msm_accumulator_y = std::get<3 + 40>(this->_data); + DataType& msm_pc = std::get<3 + 41>(this->_data); + DataType& msm_size_of_msm = std::get<3 + 42>(this->_data); + DataType& msm_count = std::get<3 + 43>(this->_data); + DataType& msm_round = std::get<3 + 44>(this->_data); + DataType& msm_q_add1 = std::get<3 + 45>(this->_data); + DataType& msm_q_add2 = std::get<3 + 46>(this->_data); + DataType& msm_q_add3 = std::get<3 + 47>(this->_data); + DataType& msm_q_add4 = std::get<3 + 48>(this->_data); + DataType& msm_x1 = std::get<3 + 49>(this->_data); + DataType& msm_y1 = std::get<3 + 50>(this->_data); + DataType& msm_x2 = std::get<3 + 51>(this->_data); + DataType& msm_y2 = std::get<3 + 52>(this->_data); + DataType& msm_x3 = std::get<3 + 53>(this->_data); + DataType& msm_y3 = std::get<3 + 54>(this->_data); + DataType& msm_x4 = std::get<3 + 55>(this->_data); + DataType& msm_y4 = std::get<3 + 56>(this->_data); + DataType& msm_collision_x1 = std::get<3 + 57>(this->_data); + DataType& msm_collision_x2 = std::get<3 + 58>(this->_data); + DataType& msm_collision_x3 = std::get<3 + 59>(this->_data); + DataType& msm_collision_x4 = std::get<3 + 60>(this->_data); + DataType& msm_lambda1 = std::get<3 + 61>(this->_data); + DataType& msm_lambda2 = std::get<3 + 62>(this->_data); + DataType& msm_lambda3 = std::get<3 + 63>(this->_data); + DataType& msm_lambda4 = std::get<3 + 64>(this->_data); + DataType& msm_slice1 = std::get<3 + 65>(this->_data); + DataType& msm_slice2 = std::get<3 + 66>(this->_data); + DataType& msm_slice3 = std::get<3 + 67>(this->_data); + DataType& msm_slice4 = std::get<3 + 68>(this->_data); + DataType& transcript_accumulator_empty = std::get<3 + 69>(this->_data); + DataType& transcript_q_reset_accumulator = std::get<3 + 70>(this->_data); + DataType& q_wnaf = std::get<3 + 71>(this->_data); + DataType& lookup_read_counts_0 = std::get<3 + 72>(this->_data); + DataType& lookup_read_counts_1 = std::get<3 + 73>(this->_data); + DataType& z_perm = std::get<3 + 74>(this->_data); + DataType& lookup_inverses = std::get<3 + 75>(this->_data); + DataType& q_transcript_mul_shift = std::get<3 + 76>(this->_data); + DataType& q_transcript_accumulate_shift = std::get<3 + 77>(this->_data); + DataType& transcript_msm_count_shift = std::get<3 + 78>(this->_data); + DataType& transcript_accumulator_x_shift = std::get<3 + 79>(this->_data); + DataType& transcript_accumulator_y_shift = std::get<3 + 80>(this->_data); + DataType& table_scalar_sum_shift = std::get<3 + 81>(this->_data); + DataType& table_dx_shift = std::get<3 + 82>(this->_data); + DataType& table_dy_shift = std::get<3 + 83>(this->_data); + DataType& table_tx_shift = std::get<3 + 84>(this->_data); + DataType& table_ty_shift = std::get<3 + 85>(this->_data); + DataType& q_msm_transition_shift = std::get<3 + 86>(this->_data); + DataType& msm_q_add_shift = std::get<3 + 87>(this->_data); + DataType& msm_q_double_shift = std::get<3 + 88>(this->_data); + DataType& msm_q_skew_shift = std::get<3 + 89>(this->_data); + DataType& msm_accumulator_x_shift = std::get<3 + 90>(this->_data); + DataType& msm_accumulator_y_shift = std::get<3 + 91>(this->_data); + DataType& msm_count_shift = std::get<3 + 92>(this->_data); + DataType& msm_round_shift = std::get<3 + 93>(this->_data); + DataType& msm_q_add1_shift = std::get<3 + 94>(this->_data); + DataType& msm_pc_shift = std::get<3 + 95>(this->_data); + DataType& table_pc_shift = std::get<3 + 96>(this->_data); + DataType& transcript_pc_shift = std::get<3 + 97>(this->_data); + DataType& table_round_shift = std::get<3 + 98>(this->_data); + DataType& transcript_accumulator_empty_shift= std::get<3 + 99>(this->_data); + DataType& q_wnaf_shift = std::get<3 + 100>(this->_data); + DataType& z_perm_shift = std::get<3 + 101>(this->_data); + + template + [[nodiscard]] const DataType& lookup_read_counts() const + { + static_assert(index == 0 || index == 1); + return std::get<75 + index>(this->_data); + } + // clang-format on + + std::vector get_wires() override + { + return { + q_transcript_add, + q_transcript_mul, + q_transcript_eq, + q_transcript_accumulate, + q_transcript_msm_transition, + transcript_pc, + transcript_msm_count, + transcript_x, + transcript_y, + transcript_z1, + transcript_z2, + transcript_z1zero, + transcript_z2zero, + transcript_op, + transcript_accumulator_x, + transcript_accumulator_y, + transcript_msm_x, + transcript_msm_y, + table_pc, + table_point_transition, + table_round, + table_scalar_sum, + table_s1, + table_s2, + table_s3, + table_s4, + table_s5, + table_s6, + table_s7, + table_s8, + table_skew, + table_dx, + table_dy, + table_tx, + table_ty, + q_msm_transition, + msm_q_add, + msm_q_double, + msm_q_skew, + msm_accumulator_x, + msm_accumulator_y, + msm_pc, + msm_size_of_msm, + msm_count, + msm_round, + msm_q_add1, + msm_q_add2, + msm_q_add3, + msm_q_add4, + msm_x1, + msm_y1, + msm_x2, + msm_y2, + msm_x3, + msm_y3, + msm_x4, + msm_y4, + msm_collision_x1, + msm_collision_x2, + msm_collision_x3, + msm_collision_x4, + msm_lambda1, + msm_lambda2, + msm_lambda3, + msm_lambda4, + msm_slice1, + msm_slice2, + msm_slice3, + msm_slice4, + transcript_accumulator_empty, + transcript_q_reset_accumulator, + q_wnaf, + lookup_read_counts_0, + lookup_read_counts_1, + }; + }; + // Gemini-specific getters. + std::vector get_unshifted() override + { + return { + lagrange_first, + lagrange_second, + lagrange_last, + q_transcript_add, + q_transcript_eq, + q_transcript_msm_transition, + transcript_x, + transcript_y, + transcript_z1, + transcript_z2, + transcript_z1zero, + transcript_z2zero, + transcript_op, + transcript_msm_x, + transcript_msm_y, + table_point_transition, + table_s1, + table_s2, + table_s3, + table_s4, + table_s5, + table_s6, + table_s7, + table_s8, + table_skew, + msm_size_of_msm, + msm_q_add2, + msm_q_add3, + msm_q_add4, + msm_x1, + msm_y1, + msm_x2, + msm_y2, + msm_x3, + msm_y3, + msm_x4, + msm_y4, + msm_collision_x1, + msm_collision_x2, + msm_collision_x3, + msm_collision_x4, + msm_lambda1, + msm_lambda2, + msm_lambda3, + msm_lambda4, + msm_slice1, + msm_slice2, + msm_slice3, + msm_slice4, + transcript_q_reset_accumulator, + lookup_read_counts_0, + lookup_read_counts_1, + lookup_inverses, + }; + }; + + std::vector get_to_be_shifted() override + { + return { + q_transcript_mul, + q_transcript_accumulate, // NOT USED + transcript_msm_count, + transcript_accumulator_x, + transcript_accumulator_y, + table_scalar_sum, + table_dx, + table_dy, + table_tx, + table_ty, + q_msm_transition, + msm_q_add, + msm_q_double, + msm_q_skew, + msm_accumulator_x, + msm_accumulator_y, + msm_count, + msm_round, + msm_q_add1, + msm_pc, + table_pc, + transcript_pc, + table_round, + transcript_accumulator_empty, + q_wnaf, + z_perm, + }; + }; + std::vector get_shifted() override + { + return { + q_transcript_mul_shift, + q_transcript_accumulate_shift, + transcript_msm_count_shift, + transcript_accumulator_x_shift, + transcript_accumulator_y_shift, + table_scalar_sum_shift, + table_dx_shift, + table_dy_shift, + table_tx_shift, + table_ty_shift, + q_msm_transition_shift, + msm_q_add_shift, + msm_q_double_shift, + msm_q_skew_shift, + msm_accumulator_x_shift, + msm_accumulator_y_shift, + msm_count_shift, + msm_round_shift, + msm_q_add1_shift, + msm_pc_shift, + table_pc_shift, + transcript_pc_shift, + table_round_shift, + transcript_accumulator_empty_shift, + q_wnaf_shift, + z_perm_shift, + }; + }; + + AllEntities() = default; + + AllEntities(const AllEntities& other) + : AllEntities_(other){}; + + AllEntities(AllEntities&& other) noexcept + : AllEntities_(other){}; + + AllEntities& operator=(const AllEntities& other) + { + if (this == &other) { + return *this; + } + AllEntities_::operator=(other); + return *this; + } + + AllEntities& operator=(AllEntities&& other) noexcept + { + AllEntities_::operator=(other); + return *this; + } + + ~AllEntities() override = default; + }; + + public: + /** + * @brief The proving key is responsible for storing the polynomials used by the prover. + * @note TODO(Cody): Maybe multiple inheritance is the right thing here. In that case, nothing should eve inherit + * from ProvingKey. + */ + class ProvingKey : public ProvingKey_, + WitnessEntities> { + public: + // Expose constructors on the base class + using Base = ProvingKey_, + WitnessEntities>; + using Base::Base; + + // The plookup wires that store plookup read data. + std::array get_table_column_wires() { return {}; }; + }; + + /** + * @brief The verification key is responsible for storing the the commitments to the precomputed (non-witnessk) + * polynomials used by the verifier. + * + * @note Note the discrepancy with what sort of data is stored here vs in the proving key. We may want to resolve + * that, and split out separate PrecomputedPolynomials/Commitments data for clarity but also for portability of our + * circuits. + */ + using VerificationKey = VerificationKey_>; + + /** + * @brief A container for polynomials handles; only stores spans. + */ + using ProverPolynomials = AllEntities; + + /** + * @brief A container for polynomials produced after the first round of sumcheck. + * @todo TODO(#394) Use polynomial classes for guaranteed memory alignment. + */ + using FoldedPolynomials = AllEntities, PolynomialHandle>; + + using RawPolynomials = AllEntities; + + /** + * @brief A container for polynomials produced after the first round of sumcheck. + * @todo TODO(#394) Use polynomial classes for guaranteed memory alignment. + */ + using RowPolynomials = AllEntities; + + /** + * @brief A container for storing the partially evaluated multivariates produced by sumcheck. + */ + class PartiallyEvaluatedMultivariates : public AllEntities { + + public: + PartiallyEvaluatedMultivariates() = default; + PartiallyEvaluatedMultivariates(const size_t circuit_size) + { + // Storage is only needed after the first partial evaluation, hence polynomials of size (n / 2) + for (auto& poly : this->_data) { + poly = Polynomial(circuit_size / 2); + } + } + }; + + /** + * @brief A container for univariates produced during the hot loop in sumcheck. + * @todo TODO(#390): Simplify this by moving MAX_RELATION_LENGTH? + */ + template + using ExtendedEdges = + AllEntities, sumcheck::Univariate>; + + /** + * @brief A container for the polynomials evaluations produced during sumcheck, which are purported to be the + * evaluations of polynomials committed in earlier rounds. + */ + class ClaimedEvaluations : public AllEntities { + public: + using Base = AllEntities; + using Base::Base; + ClaimedEvaluations(std::array _data_in) { this->_data = _data_in; } + }; + + /** + * @brief A container for commitment labels. + * @note It's debatable whether this should inherit from AllEntities. since most entries are not strictly needed. It + * has, however, been useful during debugging to have these labels available. + * + */ + class CommitmentLabels : public AllEntities { + private: + using Base = AllEntities; + + public: + CommitmentLabels() + : AllEntities() + { + Base::q_transcript_add = "Q_TRANSCRIPT_ADD"; + Base::q_transcript_mul = "Q_TRANSCRIPT_MUL"; + Base::q_transcript_eq = "Q_TRANSCRIPT_EQ"; + Base::q_transcript_accumulate = "Q_TRANSCRIPT_ACCUMULATE"; + Base::q_transcript_msm_transition = "Q_TRANSCRIPT_MSM_TRANSITION"; + Base::transcript_pc = "TRANSCRIPT_PC"; + Base::transcript_msm_count = "TRANSCRIPT_MSM_COUNT"; + Base::transcript_x = "TRANSCRIPT_X"; + Base::transcript_y = "TRANSCRIPT_Y"; + Base::transcript_z1 = "TRANSCRIPT_Z1"; + Base::transcript_z2 = "TRANSCRIPT_Z2"; + Base::transcript_z1zero = "TRANSCRIPT_Z1ZERO"; + Base::transcript_z2zero = "TRANSCRIPT_Z2ZERO"; + Base::transcript_op = "TRANSCRIPT_OP"; + Base::transcript_accumulator_x = "TRANSCRIPT_ACCUMULATOR_X"; + Base::transcript_accumulator_y = "TRANSCRIPT_ACCUMULATOR_Y"; + Base::transcript_msm_x = "TRANSCRIPT_MSM_X"; + Base::transcript_msm_y = "TRANSCRIPT_MSM_Y"; + Base::table_pc = "TABLE_PC"; + Base::table_point_transition = "TABLE_POINT_TRANSITION"; + Base::table_round = "TABLE_ROUND"; + Base::table_scalar_sum = "TABLE_SCALAR_SUM"; + Base::table_s1 = "TABLE_S1"; + Base::table_s2 = "TABLE_S2"; + Base::table_s3 = "TABLE_S3"; + Base::table_s4 = "TABLE_S4"; + Base::table_s5 = "TABLE_S5"; + Base::table_s6 = "TABLE_S6"; + Base::table_s7 = "TABLE_S7"; + Base::table_s8 = "TABLE_S8"; + Base::table_skew = "TABLE_SKEW"; + Base::table_dx = "TABLE_DX"; + Base::table_dy = "TABLE_DY"; + Base::table_tx = "TABLE_TX"; + Base::table_ty = "TABLE_TY"; + Base::q_msm_transition = "Q_MSM_TRANSITION"; + Base::msm_q_add = "MSM_Q_ADD"; + Base::msm_q_double = "MSM_Q_DOUBLE"; + Base::msm_q_skew = "MSM_Q_SKEW"; + Base::msm_accumulator_x = "MSM_ACCUMULATOR_X"; + Base::msm_accumulator_y = "MSM_ACCUMULATOR_Y"; + Base::msm_pc = "MSM_PC"; + Base::msm_size_of_msm = "MSM_SIZE_OF_MSM"; + Base::msm_count = "MSM_COUNT"; + Base::msm_round = "MSM_ROUND"; + Base::msm_q_add1 = "MSM_Q_ADD1"; + Base::msm_q_add2 = "MSM_Q_ADD2"; + Base::msm_q_add3 = "MSM_Q_ADD3"; + Base::msm_q_add4 = "MSM_Q_ADD4"; + Base::msm_x1 = "MSM_X1"; + Base::msm_y1 = "MSM_Y1"; + Base::msm_x2 = "MSM_X2"; + Base::msm_y2 = "MSM_Y2"; + Base::msm_x3 = "MSM_X3"; + Base::msm_y3 = "MSM_Y3"; + Base::msm_x4 = "MSM_X4"; + Base::msm_y4 = "MSM_Y4"; + Base::msm_collision_x1 = "MSM_COLLISION_X1"; + Base::msm_collision_x2 = "MSM_COLLISION_X2"; + Base::msm_collision_x3 = "MSM_COLLISION_X3"; + Base::msm_collision_x4 = "MSM_COLLISION_X4"; + Base::msm_lambda1 = "MSM_LAMBDA1"; + Base::msm_lambda2 = "MSM_LAMBDA2"; + Base::msm_lambda3 = "MSM_LAMBDA3"; + Base::msm_lambda4 = "MSM_LAMBDA4"; + Base::msm_slice1 = "MSM_SLICE1"; + Base::msm_slice2 = "MSM_SLICE2"; + Base::msm_slice3 = "MSM_SLICE3"; + Base::msm_slice4 = "MSM_SLICE4"; + Base::transcript_accumulator_empty = "TRANSCRIPT_ACCUMULATOR_EMPTY"; + Base::transcript_q_reset_accumulator = "TRANSCRIPT_Q_RESET_ACCUMULATOR"; + Base::q_wnaf = "Q_WNAF"; + Base::lookup_read_counts_0 = "LOOKUP_READ_COUNTS_0"; + Base::lookup_read_counts_1 = "LOOKUP_READ_COUNTS_1"; + Base::z_perm = "Z_PERM"; + Base::lookup_inverses = "LOOKUP_INVERSES"; + // The ones beginning with "__" are only used for debugging + Base::lagrange_first = "__LAGRANGE_FIRST"; + Base::lagrange_second = "__LAGRANGE_SECOND"; + Base::lagrange_last = "__LAGRANGE_LAST"; + }; + }; + + class VerifierCommitments : public AllEntities { + private: + using Base = AllEntities; + + public: + VerifierCommitments(const std::shared_ptr& verification_key, + const VerifierTranscript& transcript) + { + static_cast(transcript); + Base::lagrange_first = verification_key->lagrange_first; + Base::lagrange_second = verification_key->lagrange_second; + Base::lagrange_last = verification_key->lagrange_last; + } + }; +}; + +class ECCVM : public ECCVMBase {}; +// not actually grumpkin, need to finish supporting grumpkin in ipa +class ECCVMGrumpkin : public ECCVMBase {}; + +// NOLINTEND(cppcoreguidelines-avoid-const-or-ref-data-members) + +} // namespace flavor +namespace sumcheck { + +extern template class ECCVMTranscriptRelationBase; +extern template class ECCVMWnafRelationBase; +extern template class ECCVMPointTableRelationBase; +extern template class ECCVMMSMRelationBase; +extern template class ECCVMSetRelationBase; +extern template class ECCVMLookupRelationBase; + +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMTranscriptRelationBase, flavor::ECCVM); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMWnafRelationBase, flavor::ECCVM); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMPointTableRelationBase, flavor::ECCVM); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMMSMRelationBase, flavor::ECCVM); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMSetRelationBase, flavor::ECCVM); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMLookupRelationBase, flavor::ECCVM); + +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMTranscriptRelationBase, flavor::ECCVMGrumpkin); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMWnafRelationBase, flavor::ECCVMGrumpkin); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMPointTableRelationBase, flavor::ECCVMGrumpkin); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMMSMRelationBase, flavor::ECCVMGrumpkin); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMSetRelationBase, flavor::ECCVMGrumpkin); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMLookupRelationBase, flavor::ECCVMGrumpkin); + +DECLARE_SUMCHECK_PERMUTATION_CLASS(ECCVMSetRelationBase, flavor::ECCVM); +DECLARE_SUMCHECK_PERMUTATION_CLASS(ECCVMSetRelationBase, flavor::ECCVMGrumpkin); +} // namespace sumcheck +} // namespace proof_system::honk \ No newline at end of file diff --git a/cpp/src/barretenberg/honk/proof_system/eccvm_prover.cpp b/cpp/src/barretenberg/honk/proof_system/eccvm_prover.cpp new file mode 100644 index 0000000000..fe35e6db09 --- /dev/null +++ b/cpp/src/barretenberg/honk/proof_system/eccvm_prover.cpp @@ -0,0 +1,386 @@ +#include "eccvm_prover.hpp" +#include +#include +#include "barretenberg/honk/proof_system/prover_library.hpp" +#include "barretenberg/honk/proof_system/lookup_library.hpp" +#include "barretenberg/honk/proof_system/permutation_library.hpp" +#include "barretenberg/honk/sumcheck/relations/lookup_relation.hpp" +#include "barretenberg/honk/sumcheck/relations/permutation_relation.hpp" +#include "barretenberg/honk/sumcheck/sumcheck.hpp" +#include +#include "barretenberg/honk/sumcheck/polynomials/univariate.hpp" // will go away +#include "barretenberg/honk/utils/power_polynomial.hpp" +#include "barretenberg/honk/pcs/commitment_key.hpp" +#include +#include +#include +#include +#include "barretenberg/polynomials/polynomial.hpp" +#include "barretenberg/transcript/transcript_wrappers.hpp" +#include +#include "barretenberg/honk/pcs/claim.hpp" + +namespace proof_system::honk { + +/** + * Create ECCVMProver_ from proving key, witness and manifest. + * + * @param input_key Proving key. + * @param input_manifest Input manifest + * + * @tparam settings Settings class. + * */ +template +ECCVMProver_::ECCVMProver_(std::shared_ptr input_key, + std::shared_ptr commitment_key) + : key(input_key) + , queue(commitment_key, transcript) + , pcs_commitment_key(commitment_key) +{ + + // TODO(@zac-williamson) is there a cleaner way of doing this? + prover_polynomials.q_transcript_add = key->q_transcript_add; + prover_polynomials.q_transcript_mul = key->q_transcript_mul; + prover_polynomials.q_transcript_eq = key->q_transcript_eq; + prover_polynomials.q_transcript_accumulate = key->q_transcript_accumulate; + prover_polynomials.q_transcript_msm_transition = key->q_transcript_msm_transition; + prover_polynomials.transcript_pc = key->transcript_pc; + prover_polynomials.transcript_msm_count = key->transcript_msm_count; + prover_polynomials.transcript_x = key->transcript_x; + prover_polynomials.transcript_y = key->transcript_y; + prover_polynomials.transcript_z1 = key->transcript_z1; + prover_polynomials.transcript_z2 = key->transcript_z2; + prover_polynomials.transcript_z1zero = key->transcript_z1zero; + prover_polynomials.transcript_z2zero = key->transcript_z2zero; + prover_polynomials.transcript_op = key->transcript_op; + prover_polynomials.transcript_accumulator_x = key->transcript_accumulator_x; + prover_polynomials.transcript_accumulator_y = key->transcript_accumulator_y; + prover_polynomials.transcript_msm_x = key->transcript_msm_x; + prover_polynomials.transcript_msm_y = key->transcript_msm_y; + prover_polynomials.table_pc = key->table_pc; + prover_polynomials.table_point_transition = key->table_point_transition; + prover_polynomials.table_round = key->table_round; + prover_polynomials.table_scalar_sum = key->table_scalar_sum; + prover_polynomials.table_s1 = key->table_s1; + prover_polynomials.table_s2 = key->table_s2; + prover_polynomials.table_s3 = key->table_s3; + prover_polynomials.table_s4 = key->table_s4; + prover_polynomials.table_s5 = key->table_s5; + prover_polynomials.table_s6 = key->table_s6; + prover_polynomials.table_s7 = key->table_s7; + prover_polynomials.table_s8 = key->table_s8; + prover_polynomials.table_skew = key->table_skew; + prover_polynomials.table_dx = key->table_dx; + prover_polynomials.table_dy = key->table_dy; + prover_polynomials.table_tx = key->table_tx; + prover_polynomials.table_ty = key->table_ty; + prover_polynomials.q_msm_transition = key->q_msm_transition; + prover_polynomials.msm_q_add = key->msm_q_add; + prover_polynomials.msm_q_double = key->msm_q_double; + prover_polynomials.msm_q_skew = key->msm_q_skew; + prover_polynomials.msm_accumulator_x = key->msm_accumulator_x; + prover_polynomials.msm_accumulator_y = key->msm_accumulator_y; + prover_polynomials.msm_pc = key->msm_pc; + prover_polynomials.msm_size_of_msm = key->msm_size_of_msm; + prover_polynomials.msm_count = key->msm_count; + prover_polynomials.msm_round = key->msm_round; + prover_polynomials.msm_q_add1 = key->msm_q_add1; + prover_polynomials.msm_q_add2 = key->msm_q_add2; + prover_polynomials.msm_q_add3 = key->msm_q_add3; + prover_polynomials.msm_q_add4 = key->msm_q_add4; + prover_polynomials.msm_x1 = key->msm_x1; + prover_polynomials.msm_y1 = key->msm_y1; + prover_polynomials.msm_x2 = key->msm_x2; + prover_polynomials.msm_y2 = key->msm_y2; + prover_polynomials.msm_x3 = key->msm_x3; + prover_polynomials.msm_y3 = key->msm_y3; + prover_polynomials.msm_x4 = key->msm_x4; + prover_polynomials.msm_y4 = key->msm_y4; + prover_polynomials.msm_collision_x1 = key->msm_collision_x1; + prover_polynomials.msm_collision_x2 = key->msm_collision_x2; + prover_polynomials.msm_collision_x3 = key->msm_collision_x3; + prover_polynomials.msm_collision_x4 = key->msm_collision_x4; + prover_polynomials.msm_lambda1 = key->msm_lambda1; + prover_polynomials.msm_lambda2 = key->msm_lambda2; + prover_polynomials.msm_lambda3 = key->msm_lambda3; + prover_polynomials.msm_lambda4 = key->msm_lambda4; + prover_polynomials.msm_slice1 = key->msm_slice1; + prover_polynomials.msm_slice2 = key->msm_slice2; + prover_polynomials.msm_slice3 = key->msm_slice3; + prover_polynomials.msm_slice4 = key->msm_slice4; + prover_polynomials.transcript_accumulator_empty = key->transcript_accumulator_empty; + prover_polynomials.transcript_q_reset_accumulator = key->transcript_q_reset_accumulator; + prover_polynomials.q_wnaf = key->q_wnaf; + prover_polynomials.lookup_read_counts_0 = key->lookup_read_counts_0; + prover_polynomials.lookup_read_counts_1 = key->lookup_read_counts_1; + prover_polynomials.q_transcript_mul_shift = key->q_transcript_mul.shifted(); + prover_polynomials.q_transcript_accumulate_shift = key->q_transcript_accumulate.shifted(); + prover_polynomials.transcript_msm_count_shift = key->transcript_msm_count.shifted(); + prover_polynomials.transcript_accumulator_x_shift = key->transcript_accumulator_x.shifted(); + prover_polynomials.transcript_accumulator_y_shift = key->transcript_accumulator_y.shifted(); + prover_polynomials.table_scalar_sum_shift = key->table_scalar_sum.shifted(); + prover_polynomials.table_dx_shift = key->table_dx.shifted(); + prover_polynomials.table_dy_shift = key->table_dy.shifted(); + prover_polynomials.table_tx_shift = key->table_tx.shifted(); + prover_polynomials.table_ty_shift = key->table_ty.shifted(); + prover_polynomials.q_msm_transition_shift = key->q_msm_transition.shifted(); + prover_polynomials.msm_q_add_shift = key->msm_q_add.shifted(); + prover_polynomials.msm_q_double_shift = key->msm_q_double.shifted(); + prover_polynomials.msm_q_skew_shift = key->msm_q_skew.shifted(); + prover_polynomials.msm_accumulator_x_shift = key->msm_accumulator_x.shifted(); + prover_polynomials.msm_accumulator_y_shift = key->msm_accumulator_y.shifted(); + prover_polynomials.msm_count_shift = key->msm_count.shifted(); + prover_polynomials.msm_round_shift = key->msm_round.shifted(); + prover_polynomials.msm_q_add1_shift = key->msm_q_add1.shifted(); + prover_polynomials.msm_pc_shift = key->msm_pc.shifted(); + prover_polynomials.table_pc_shift = key->table_pc.shifted(); + prover_polynomials.transcript_pc_shift = key->transcript_pc.shifted(); + prover_polynomials.table_round_shift = key->table_round.shifted(); + prover_polynomials.transcript_accumulator_empty_shift = key->transcript_accumulator_empty.shifted(); + prover_polynomials.q_wnaf_shift = key->q_wnaf.shifted(); + prover_polynomials.lagrange_first = key->lagrange_first; + prover_polynomials.lagrange_second = key->lagrange_second; + prover_polynomials.lagrange_last = key->lagrange_last; + + prover_polynomials.lookup_inverses = key->lookup_inverses; + key->z_perm = Polynomial(key->circuit_size); + prover_polynomials.z_perm = key->z_perm; +} + +/** + * @brief Commit to the first three wires only + * + */ +template void ECCVMProver_::compute_wire_commitments() +{ + auto wire_polys = key->get_wires(); + auto labels = commitment_labels.get_wires(); + for (size_t idx = 0; idx < wire_polys.size(); ++idx) { + queue.add_commitment(wire_polys[idx], labels[idx]); + } +} + +/** + * @brief Add circuit size, public input size, and public inputs to transcript + * + */ +template void ECCVMProver_::execute_preamble_round() +{ + const auto circuit_size = static_cast(key->circuit_size); + const auto num_public_inputs = static_cast(key->num_public_inputs); + + transcript.send_to_verifier("circuit_size", circuit_size); +} + +/** + * @brief Compute commitments to the first three wires + * + */ +template void ECCVMProver_::execute_wire_commitments_round() +{ + auto wire_polys = key->get_wires(); + auto labels = commitment_labels.get_wires(); + for (size_t idx = 0; idx < wire_polys.size(); ++idx) { + queue.add_commitment(wire_polys[idx], labels[idx]); + } +} + +/** + * @brief Compute sorted witness-table accumulator + * + */ +template void ECCVMProver_::execute_log_derivative_commitments_round() +{ + // Compute and add eta to relation parameters + auto [eta, gamma] = transcript.get_challenges("beta", "gamma"); + // TODO(#583)(@zac-williamson): fix Transcript to be able to generate more than 2 challenges per round! oof. + auto eta_sqr = eta * eta; + relation_parameters.gamma = gamma; + relation_parameters.eta = eta; + relation_parameters.eta_sqr = eta_sqr; + relation_parameters.eta_cube = eta_sqr * eta; + relation_parameters.permutation_offset = + gamma * (gamma + eta_sqr) * (gamma + eta_sqr + eta_sqr) * (gamma + eta_sqr + eta_sqr + eta_sqr); + relation_parameters.permutation_offset = relation_parameters.permutation_offset.invert(); + // Compute inverse polynomial for our logarithmic-derivative lookup method + lookup_library::compute_logderivative_inverse( + prover_polynomials, relation_parameters, key->circuit_size); + queue.add_commitment(key->lookup_inverses, commitment_labels.lookup_inverses); + prover_polynomials.lookup_inverses = key->lookup_inverses; +} + +/** + * @brief Compute permutation and lookup grand product polynomials and commitments + * + */ +template void ECCVMProver_::execute_grand_product_computation_round() +{ + // Compute permutation grand product and their commitments + permutation_library::compute_permutation_grand_products(key, prover_polynomials, relation_parameters); + + queue.add_commitment(key->z_perm, commitment_labels.z_perm); +} + +/** + * @brief Run Sumcheck resulting in u = (u_1,...,u_d) challenges and all evaluations at u being calculated. + * + */ +template void ECCVMProver_::execute_relation_check_rounds() +{ + using Sumcheck = sumcheck::Sumcheck>; + + auto sumcheck = Sumcheck(key->circuit_size, transcript); + + sumcheck_output = sumcheck.execute_prover(prover_polynomials, relation_parameters); +} + +/** + * - Get rho challenge + * - Compute d+1 Fold polynomials and their evaluations. + * + * */ +template void ECCVMProver_::execute_univariatization_round() +{ + const size_t NUM_POLYNOMIALS = Flavor::NUM_ALL_ENTITIES; + + // Generate batching challenge ρ and powers 1,ρ,…,ρᵐ⁻¹ + FF rho = transcript.get_challenge("rho"); + std::vector rhos = Gemini::powers_of_rho(rho, NUM_POLYNOMIALS); + + // Batch the unshifted polynomials and the to-be-shifted polynomials using ρ + Polynomial batched_poly_unshifted(key->circuit_size); // batched unshifted polynomials + size_t poly_idx = 0; // TODO(#391) zip + for (auto& unshifted_poly : prover_polynomials.get_unshifted()) { + batched_poly_unshifted.add_scaled(unshifted_poly, rhos[poly_idx]); + ++poly_idx; + } + + Polynomial batched_poly_to_be_shifted(key->circuit_size); // batched to-be-shifted polynomials + for (auto& to_be_shifted_poly : prover_polynomials.get_to_be_shifted()) { + batched_poly_to_be_shifted.add_scaled(to_be_shifted_poly, rhos[poly_idx]); + ++poly_idx; + }; + + // Compute d-1 polynomials Fold^(i), i = 1, ..., d-1. + fold_polynomials = Gemini::compute_fold_polynomials( + sumcheck_output.challenge_point, std::move(batched_poly_unshifted), std::move(batched_poly_to_be_shifted)); + + // Compute and add to trasnscript the commitments [Fold^(i)], i = 1, ..., d-1 + for (size_t l = 0; l < key->log_circuit_size - 1; ++l) { + queue.add_commitment(fold_polynomials[l + 2], "Gemini:FOLD_" + std::to_string(l + 1)); + } +} + +/** + * - Do Fiat-Shamir to get "r" challenge + * - Compute remaining two partially evaluated Fold polynomials Fold_{r}^(0) and Fold_{-r}^(0). + * - Compute and aggregate opening pairs (challenge, evaluation) for each of d Fold polynomials. + * - Add d-many Fold evaluations a_i, i = 0, ..., d-1 to the transcript, excluding eval of Fold_{r}^(0) + * */ +template void ECCVMProver_::execute_pcs_evaluation_round() +{ + const FF r_challenge = transcript.get_challenge("Gemini:r"); + gemini_output = Gemini::compute_fold_polynomial_evaluations( + sumcheck_output.challenge_point, std::move(fold_polynomials), r_challenge); + + for (size_t l = 0; l < key->log_circuit_size; ++l) { + std::string label = "Gemini:a_" + std::to_string(l); + const auto& evaluation = gemini_output.opening_pairs[l + 1].evaluation; + transcript.send_to_verifier(label, evaluation); + } +} + +/** + * - Do Fiat-Shamir to get "nu" challenge. + * - Compute commitment [Q]_1 + * */ +template void ECCVMProver_::execute_shplonk_batched_quotient_round() +{ + nu_challenge = transcript.get_challenge("Shplonk:nu"); + + batched_quotient_Q = + Shplonk::compute_batched_quotient(gemini_output.opening_pairs, gemini_output.witnesses, nu_challenge); + + // commit to Q(X) and add [Q] to the transcript + queue.add_commitment(batched_quotient_Q, "Shplonk:Q"); +} + +/** + * - Do Fiat-Shamir to get "z" challenge. + * - Compute polynomial Q(X) - Q_z(X) + * */ +template void ECCVMProver_::execute_shplonk_partial_evaluation_round() +{ + const FF z_challenge = transcript.get_challenge("Shplonk:z"); + + shplonk_output = Shplonk::compute_partially_evaluated_batched_quotient( + gemini_output.opening_pairs, gemini_output.witnesses, std::move(batched_quotient_Q), nu_challenge, z_challenge); +} +/** + * - Compute final PCS opening proof: + * - For KZG, this is the quotient commitment [W]_1 + * - For IPA, the vectors L and R + * */ +template void ECCVMProver_::execute_final_pcs_round() +{ + PCS::compute_opening_proof(pcs_commitment_key, shplonk_output.opening_pair, shplonk_output.witness, transcript); + // queue.add_commitment(quotient_W, "KZG:W"); +} + +template plonk::proof& ECCVMProver_::export_proof() +{ + proof.proof_data = transcript.proof_data; + return proof; +} + +template plonk::proof& ECCVMProver_::construct_proof() +{ + // Add circuit size public input size and public inputs to transcript. + execute_preamble_round(); + + // Compute first three wire commitments + execute_wire_commitments_round(); + queue.process_queue(); + + // Compute sorted list accumulator and commitment + execute_log_derivative_commitments_round(); + queue.process_queue(); + + // Fiat-Shamir: beta & gamma + // Compute grand product(s) and commitments. + execute_grand_product_computation_round(); + queue.process_queue(); + + // Fiat-Shamir: alpha + // Run sumcheck subprotocol. + execute_relation_check_rounds(); + + // Fiat-Shamir: rho + // Compute Fold polynomials and their commitments. + execute_univariatization_round(); + queue.process_queue(); + + // Fiat-Shamir: r + // Compute Fold evaluations + execute_pcs_evaluation_round(); + + // Fiat-Shamir: nu + // Compute Shplonk batched quotient commitment Q + execute_shplonk_batched_quotient_round(); + queue.process_queue(); + + // Fiat-Shamir: z + // Compute partial evaluation Q_z + execute_shplonk_partial_evaluation_round(); + + // Fiat-Shamir: z + // Compute PCS opening proof (either KZG quotient commitment or IPA opening proof) + execute_final_pcs_round(); + + return export_proof(); +} + +template class ECCVMProver_; +template class ECCVMProver_; + +} // namespace proof_system::honk diff --git a/cpp/src/barretenberg/honk/proof_system/eccvm_prover.hpp b/cpp/src/barretenberg/honk/proof_system/eccvm_prover.hpp new file mode 100644 index 0000000000..728e1f4acc --- /dev/null +++ b/cpp/src/barretenberg/honk/proof_system/eccvm_prover.hpp @@ -0,0 +1,85 @@ +#pragma once +#include "barretenberg/honk/proof_system/work_queue.hpp" +#include "barretenberg/plonk/proof_system/types/proof.hpp" +#include "barretenberg/honk/pcs/gemini/gemini.hpp" +#include "barretenberg/honk/pcs/shplonk/shplonk_single.hpp" +#include "barretenberg/honk/transcript/transcript.hpp" +#include "barretenberg/honk/flavor/ecc_vm.hpp" +#include "barretenberg/honk/sumcheck/relations/relation_parameters.hpp" +#include "barretenberg/honk/sumcheck/sumcheck_output.hpp" + +namespace proof_system::honk { + +// We won't compile this class with honk::flavor::Standard, but we will like want to compile it (at least for testing) +// with a flavor that uses the curve Grumpkin, or a flavor that does/does not have zk, etc. +template class ECCVMProver_ { + + using FF = typename Flavor::FF; + using PCSParams = typename Flavor::PCSParams; + using PCS = typename Flavor::PCS; + using PCSCommitmentKey = typename Flavor::PCSParams::CommitmentKey; + using ProvingKey = typename Flavor::ProvingKey; + using Polynomial = typename Flavor::Polynomial; + using ProverPolynomials = typename Flavor::ProverPolynomials; + using CommitmentLabels = typename Flavor::CommitmentLabels; + + public: + explicit ECCVMProver_(std::shared_ptr input_key, std::shared_ptr commitment_key); + + void execute_preamble_round(); + void execute_wire_commitments_round(); + void execute_log_derivative_commitments_round(); + void execute_grand_product_computation_round(); + void execute_relation_check_rounds(); + void execute_univariatization_round(); + void execute_pcs_evaluation_round(); + void execute_shplonk_batched_quotient_round(); + void execute_shplonk_partial_evaluation_round(); + void execute_final_pcs_round(); + + void compute_wire_commitments(); + + plonk::proof& export_proof(); + plonk::proof& construct_proof(); + + ProverTranscript transcript; + + std::vector public_inputs; + + sumcheck::RelationParameters relation_parameters; + + std::shared_ptr key; + + // Container for spans of all polynomials required by the prover (i.e. all multivariates evaluated by Sumcheck). + ProverPolynomials prover_polynomials; + + CommitmentLabels commitment_labels; + + // Container for d + 1 Fold polynomials produced by Gemini + std::vector fold_polynomials; + + Polynomial batched_quotient_Q; // batched quotient poly computed by Shplonk + FF nu_challenge; // needed in both Shplonk rounds + + Polynomial quotient_W; + + work_queue queue; + + sumcheck::SumcheckOutput sumcheck_output; + pcs::gemini::ProverOutput gemini_output; + pcs::shplonk::ProverOutput shplonk_output; + std::shared_ptr pcs_commitment_key; + + using Gemini = pcs::gemini::MultilinearReductionScheme; + using Shplonk = pcs::shplonk::SingleBatchOpeningScheme; + + private: + plonk::proof proof; +}; + +extern template class ECCVMProver_; +extern template class ECCVMProver_; + +using ECCVMProver = ECCVMProver_; + +} // namespace proof_system::honk diff --git a/cpp/src/barretenberg/honk/proof_system/eccvm_verifier.cpp b/cpp/src/barretenberg/honk/proof_system/eccvm_verifier.cpp new file mode 100644 index 0000000000..f3accf7631 --- /dev/null +++ b/cpp/src/barretenberg/honk/proof_system/eccvm_verifier.cpp @@ -0,0 +1,256 @@ +#include "./eccvm_verifier.hpp" +#include "barretenberg/honk/transcript/transcript.hpp" +#include "barretenberg/numeric/bitop/get_msb.hpp" +#include "barretenberg/honk/flavor/standard.hpp" +#include "barretenberg/honk/utils/power_polynomial.hpp" + +using namespace barretenberg; +using namespace proof_system::honk::sumcheck; + +namespace proof_system::honk { +template +ECCVMVerifier_::ECCVMVerifier_(std::shared_ptr verifier_key) + : key(verifier_key) +{} + +template +ECCVMVerifier_::ECCVMVerifier_(ECCVMVerifier_&& other) noexcept + : key(std::move(other.key)) + , pcs_verification_key(std::move(other.pcs_verification_key)) +{} + +template ECCVMVerifier_& ECCVMVerifier_::operator=(ECCVMVerifier_&& other) noexcept +{ + key = other.key; + pcs_verification_key = (std::move(other.pcs_verification_key)); + commitments.clear(); + pcs_fr_elements.clear(); + return *this; +} + +/** + * @brief This function verifies an ECCVM Honk proof for given program settings. + * + */ +template bool ECCVMVerifier_::verify_proof(const plonk::proof& proof) +{ + using FF = typename Flavor::FF; + using GroupElement = typename Flavor::GroupElement; + using Commitment = typename Flavor::Commitment; + using PCSParams = typename Flavor::PCSParams; + using PCS = typename Flavor::PCS; + using Gemini = pcs::gemini::MultilinearReductionScheme; + using Shplonk = pcs::shplonk::SingleBatchOpeningScheme; + using VerifierCommitments = typename Flavor::VerifierCommitments; + using CommitmentLabels = typename Flavor::CommitmentLabels; + + RelationParameters relation_parameters; + + transcript = VerifierTranscript{ proof.proof_data }; + + auto commitments = VerifierCommitments(key, transcript); + auto commitment_labels = CommitmentLabels(); + + // TODO(Adrian): Change the initialization of the transcript to take the VK hash? + const auto circuit_size = transcript.template receive_from_prover("circuit_size"); + + if (circuit_size != key->circuit_size) { + return false; + } + + // Get commitments to VM wires + commitments.q_transcript_add = + transcript.template receive_from_prover(commitment_labels.q_transcript_add); + commitments.q_transcript_mul = + transcript.template receive_from_prover(commitment_labels.q_transcript_mul); + commitments.q_transcript_eq = + transcript.template receive_from_prover(commitment_labels.q_transcript_eq); + commitments.q_transcript_accumulate = + transcript.template receive_from_prover(commitment_labels.q_transcript_accumulate); + commitments.q_transcript_msm_transition = + transcript.template receive_from_prover(commitment_labels.q_transcript_msm_transition); + commitments.transcript_pc = transcript.template receive_from_prover(commitment_labels.transcript_pc); + commitments.transcript_msm_count = + transcript.template receive_from_prover(commitment_labels.transcript_msm_count); + commitments.transcript_x = transcript.template receive_from_prover(commitment_labels.transcript_x); + commitments.transcript_y = transcript.template receive_from_prover(commitment_labels.transcript_y); + commitments.transcript_z1 = transcript.template receive_from_prover(commitment_labels.transcript_z1); + commitments.transcript_z2 = transcript.template receive_from_prover(commitment_labels.transcript_z2); + commitments.transcript_z1zero = + transcript.template receive_from_prover(commitment_labels.transcript_z1zero); + commitments.transcript_z2zero = + transcript.template receive_from_prover(commitment_labels.transcript_z2zero); + commitments.transcript_op = transcript.template receive_from_prover(commitment_labels.transcript_op); + commitments.transcript_accumulator_x = + transcript.template receive_from_prover(commitment_labels.transcript_accumulator_x); + commitments.transcript_accumulator_y = + transcript.template receive_from_prover(commitment_labels.transcript_accumulator_y); + commitments.transcript_msm_x = + transcript.template receive_from_prover(commitment_labels.transcript_msm_x); + commitments.transcript_msm_y = + transcript.template receive_from_prover(commitment_labels.transcript_msm_y); + commitments.table_pc = transcript.template receive_from_prover(commitment_labels.table_pc); + commitments.table_point_transition = + transcript.template receive_from_prover(commitment_labels.table_point_transition); + commitments.table_round = transcript.template receive_from_prover(commitment_labels.table_round); + commitments.table_scalar_sum = + transcript.template receive_from_prover(commitment_labels.table_scalar_sum); + commitments.table_s1 = transcript.template receive_from_prover(commitment_labels.table_s1); + commitments.table_s2 = transcript.template receive_from_prover(commitment_labels.table_s2); + commitments.table_s3 = transcript.template receive_from_prover(commitment_labels.table_s3); + commitments.table_s4 = transcript.template receive_from_prover(commitment_labels.table_s4); + commitments.table_s5 = transcript.template receive_from_prover(commitment_labels.table_s5); + commitments.table_s6 = transcript.template receive_from_prover(commitment_labels.table_s6); + commitments.table_s7 = transcript.template receive_from_prover(commitment_labels.table_s7); + commitments.table_s8 = transcript.template receive_from_prover(commitment_labels.table_s8); + commitments.table_skew = transcript.template receive_from_prover(commitment_labels.table_skew); + commitments.table_dx = transcript.template receive_from_prover(commitment_labels.table_dx); + commitments.table_dy = transcript.template receive_from_prover(commitment_labels.table_dy); + commitments.table_tx = transcript.template receive_from_prover(commitment_labels.table_tx); + commitments.table_ty = transcript.template receive_from_prover(commitment_labels.table_ty); + commitments.q_msm_transition = + transcript.template receive_from_prover(commitment_labels.q_msm_transition); + commitments.msm_q_add = transcript.template receive_from_prover(commitment_labels.msm_q_add); + commitments.msm_q_double = transcript.template receive_from_prover(commitment_labels.msm_q_double); + commitments.msm_q_skew = transcript.template receive_from_prover(commitment_labels.msm_q_skew); + commitments.msm_accumulator_x = + transcript.template receive_from_prover(commitment_labels.msm_accumulator_x); + commitments.msm_accumulator_y = + transcript.template receive_from_prover(commitment_labels.msm_accumulator_y); + commitments.msm_pc = transcript.template receive_from_prover(commitment_labels.msm_pc); + commitments.msm_size_of_msm = + transcript.template receive_from_prover(commitment_labels.msm_size_of_msm); + commitments.msm_count = transcript.template receive_from_prover(commitment_labels.msm_count); + commitments.msm_round = transcript.template receive_from_prover(commitment_labels.msm_round); + commitments.msm_q_add1 = transcript.template receive_from_prover(commitment_labels.msm_q_add1); + commitments.msm_q_add2 = transcript.template receive_from_prover(commitment_labels.msm_q_add2); + commitments.msm_q_add3 = transcript.template receive_from_prover(commitment_labels.msm_q_add3); + commitments.msm_q_add4 = transcript.template receive_from_prover(commitment_labels.msm_q_add4); + commitments.msm_x1 = transcript.template receive_from_prover(commitment_labels.msm_x1); + commitments.msm_y1 = transcript.template receive_from_prover(commitment_labels.msm_y1); + commitments.msm_x2 = transcript.template receive_from_prover(commitment_labels.msm_x2); + commitments.msm_y2 = transcript.template receive_from_prover(commitment_labels.msm_y2); + commitments.msm_x3 = transcript.template receive_from_prover(commitment_labels.msm_x3); + commitments.msm_y3 = transcript.template receive_from_prover(commitment_labels.msm_y3); + commitments.msm_x4 = transcript.template receive_from_prover(commitment_labels.msm_x4); + commitments.msm_y4 = transcript.template receive_from_prover(commitment_labels.msm_y4); + commitments.msm_collision_x1 = + transcript.template receive_from_prover(commitment_labels.msm_collision_x1); + commitments.msm_collision_x2 = + transcript.template receive_from_prover(commitment_labels.msm_collision_x2); + commitments.msm_collision_x3 = + transcript.template receive_from_prover(commitment_labels.msm_collision_x3); + commitments.msm_collision_x4 = + transcript.template receive_from_prover(commitment_labels.msm_collision_x4); + commitments.msm_lambda1 = transcript.template receive_from_prover(commitment_labels.msm_lambda1); + commitments.msm_lambda2 = transcript.template receive_from_prover(commitment_labels.msm_lambda2); + commitments.msm_lambda3 = transcript.template receive_from_prover(commitment_labels.msm_lambda3); + commitments.msm_lambda4 = transcript.template receive_from_prover(commitment_labels.msm_lambda4); + commitments.msm_slice1 = transcript.template receive_from_prover(commitment_labels.msm_slice1); + commitments.msm_slice2 = transcript.template receive_from_prover(commitment_labels.msm_slice2); + commitments.msm_slice3 = transcript.template receive_from_prover(commitment_labels.msm_slice3); + commitments.msm_slice4 = transcript.template receive_from_prover(commitment_labels.msm_slice4); + commitments.transcript_accumulator_empty = + transcript.template receive_from_prover(commitment_labels.transcript_accumulator_empty); + commitments.transcript_q_reset_accumulator = + transcript.template receive_from_prover(commitment_labels.transcript_q_reset_accumulator); + commitments.q_wnaf = transcript.template receive_from_prover(commitment_labels.q_wnaf); + commitments.lookup_read_counts_0 = + transcript.template receive_from_prover(commitment_labels.lookup_read_counts_0); + commitments.lookup_read_counts_1 = + transcript.template receive_from_prover(commitment_labels.lookup_read_counts_1); + + // Get challenge for sorted list batching and wire four memory records + auto [eta, gamma] = transcript.get_challenges("beta", "gamma"); + relation_parameters.gamma = gamma; + auto eta_sqr = eta * eta; + relation_parameters.eta = eta; + relation_parameters.eta_sqr = eta_sqr; + relation_parameters.eta_cube = eta_sqr * eta; + relation_parameters.permutation_offset = + gamma * (gamma + eta_sqr) * (gamma + eta_sqr + eta_sqr) * (gamma + eta_sqr + eta_sqr + eta_sqr); + relation_parameters.permutation_offset = relation_parameters.permutation_offset.invert(); + + // Get commitment to permutation and lookup grand products + commitments.lookup_inverses = + transcript.template receive_from_prover(commitment_labels.lookup_inverses); + commitments.z_perm = transcript.template receive_from_prover(commitment_labels.z_perm); + + // Execute Sumcheck Verifier + auto sumcheck = Sumcheck>(circuit_size, transcript); + + std::optional sumcheck_output = sumcheck.execute_verifier(relation_parameters); + + // If Sumcheck does not return an output, sumcheck verification has failed + if (!sumcheck_output.has_value()) { + return false; + } + + auto [multivariate_challenge, purported_evaluations] = *sumcheck_output; + + // Execute Gemini/Shplonk verification: + + // Construct inputs for Gemini verifier: + // - Multivariate opening point u = (u_0, ..., u_{d-1}) + // - batched unshifted and to-be-shifted polynomial commitments + auto batched_commitment_unshifted = GroupElement::zero(); + auto batched_commitment_to_be_shifted = GroupElement::zero(); + + // Compute powers of batching challenge rho + FF rho = transcript.get_challenge("rho"); + std::vector rhos = Gemini::powers_of_rho(rho, Flavor::NUM_ALL_ENTITIES); + + // Compute batched multivariate evaluation + FF batched_evaluation = FF::zero(); + size_t evaluation_idx = 0; + for (auto& value : purported_evaluations.get_unshifted()) { + batched_evaluation += value * rhos[evaluation_idx]; + ++evaluation_idx; + } + for (auto& value : purported_evaluations.get_shifted()) { + batched_evaluation += value * rhos[evaluation_idx]; + ++evaluation_idx; + } + + // Construct batched commitment for NON-shifted polynomials + size_t commitment_idx = 0; + for (auto& commitment : commitments.get_unshifted()) { + // very lazy point at infinity check. not complete. fix. + if (commitment.y != 0) { + batched_commitment_unshifted += commitment * rhos[commitment_idx]; + } else { + std::cout << "point at infinity (unshifted)" << std::endl; + } + ++commitment_idx; + } + + // Construct batched commitment for to-be-shifted polynomials + for (auto& commitment : commitments.get_to_be_shifted()) { + // very lazy point at infinity check. not complete. fix. + if (commitment.y != 0) { + batched_commitment_to_be_shifted += commitment * rhos[commitment_idx]; + } else { + std::cout << "point at infinity (to be shifted)" << std::endl; + } + ++commitment_idx; + } + + // Produce a Gemini claim consisting of: + // - d+1 commitments [Fold_{r}^(0)], [Fold_{-r}^(0)], and [Fold^(l)], l = 1:d-1 + // - d+1 evaluations a_0_pos, and a_l, l = 0:d-1 + auto gemini_claim = Gemini::reduce_verify(multivariate_challenge, + batched_evaluation, + batched_commitment_unshifted, + batched_commitment_to_be_shifted, + transcript); + + // Produce a Shplonk claim: commitment [Q] - [Q_z], evaluation zero (at random challenge z) + auto shplonk_claim = Shplonk::reduce_verify(gemini_claim, transcript); + + // // Verify the Shplonk claim with KZG or IPA + return PCS::verify(pcs_verification_key, shplonk_claim, transcript); +} + +template class ECCVMVerifier_; + +} // namespace proof_system::honk diff --git a/cpp/src/barretenberg/honk/proof_system/eccvm_verifier.hpp b/cpp/src/barretenberg/honk/proof_system/eccvm_verifier.hpp new file mode 100644 index 0000000000..3c83f6dddd --- /dev/null +++ b/cpp/src/barretenberg/honk/proof_system/eccvm_verifier.hpp @@ -0,0 +1,47 @@ +#pragma once +#include "barretenberg/honk/flavor/ecc_vm.hpp" +#include "barretenberg/plonk/proof_system/types/proof.hpp" +#include "barretenberg/honk/sumcheck/sumcheck.hpp" + +namespace proof_system::honk { +template class ECCVMVerifier_ { + using FF = typename Flavor::FF; + using Commitment = typename Flavor::Commitment; + using VerificationKey = typename Flavor::VerificationKey; + using PCSVerificationKey = typename Flavor::PCSParams::VerificationKey; + + public: + explicit ECCVMVerifier_(std::shared_ptr verifier_key = nullptr); + ECCVMVerifier_(std::shared_ptr key, + std::map commitments, + std::map pcs_fr_elements, + std::shared_ptr pcs_verification_key, + VerifierTranscript transcript) + : key(std::move(key)) + , commitments(std::move(commitments)) + , pcs_fr_elements(std::move(pcs_fr_elements)) + , pcs_verification_key(std::move(pcs_verification_key)) + , transcript(std::move(transcript)) + {} + ECCVMVerifier_(ECCVMVerifier_&& other) noexcept; + ECCVMVerifier_(const ECCVMVerifier_& other) = delete; + ECCVMVerifier_& operator=(const ECCVMVerifier_& other) = delete; + ECCVMVerifier_& operator=(ECCVMVerifier_&& other) noexcept; + ~ECCVMVerifier_() = default; + + bool verify_proof(const plonk::proof& proof); + + std::shared_ptr key; + std::map commitments; + std::map pcs_fr_elements; + std::shared_ptr pcs_verification_key; + VerifierTranscript transcript; +}; + +extern template class ECCVMVerifier_; +extern template class ECCVMVerifier_; + +using ECCVMVerifier = ECCVMVerifier_; +using ECCVMVerifierGrumpkin = ECCVMVerifier_; + +} // namespace proof_system::honk diff --git a/cpp/src/barretenberg/honk/proof_system/lookup_library.hpp b/cpp/src/barretenberg/honk/proof_system/lookup_library.hpp new file mode 100644 index 0000000000..65f36427d7 --- /dev/null +++ b/cpp/src/barretenberg/honk/proof_system/lookup_library.hpp @@ -0,0 +1,64 @@ +#pragma once +#include "barretenberg/honk/sumcheck/sumcheck.hpp" +#include + +namespace proof_system::honk::lookup_library { + +/** + * @brief Compute the inverse polynomial I(X) required for logderivative lookups + * * + * @details + * Inverse may be defined in terms of its values on X_i = 0,1,...,n-1 as Z_perm[0] = 1 and for i = 1:n-1 + * 1 1 + * Inverse[i] = ∏ -------------------------- * ∏' -------------------------- + * relation::read_term(j) relation::write_term(j) + * + * where ∏ := ∏_{j=0:relation::NUM_READ_TERMS-1} and ∏' := ∏'_{j=0:relation::NUM_WRITE_TERMS-1} + * + * If row [i] does not contain a lookup read gate or a write gate, Inverse[i] = 0 + * N.B. by "write gate" we mean; do the lookup table polynomials contain nonzero values at this row? + * (in the ECCVM, the lookup table is not precomputed, so we have a concept of a "write gate", unlike when precomputed + * lookup tables are used) + * + * The specific algebraic relations that define read terms and write terms are defined in Flavor::LookupRelation + * + */ +template +void compute_logderivative_inverse(auto& polynomials, + sumcheck::RelationParameters& relation_parameters, + const size_t circuit_size) +{ + using FF = typename Flavor::FF; + using Accumulator = typename Relation::ValueAccumTypes; + constexpr size_t READ_TERMS = Relation::READ_TERMS; + constexpr size_t WRITE_TERMS = Relation::WRITE_TERMS; + auto& inverse_polynomial = polynomials.lookup_inverses; + // auto& inverse_polynomial = key->lookup_inverses; + // const size_t circuit_size = key->circuit_size; + + auto lookup_relation = Relation(); + for (size_t i = 0; i < circuit_size; ++i) { + bool has_inverse = + lookup_relation.template lookup_exists_at_row_index(polynomials, relation_parameters, i); + if (!has_inverse) { + continue; + } + FF denominator = 1; + barretenberg::constexpr_for<0, READ_TERMS, 1>([&] { + auto denominator_term = lookup_relation.template compute_read_term( + polynomials, relation_parameters, i); + denominator *= denominator_term; + }); + barretenberg::constexpr_for<0, WRITE_TERMS, 1>([&] { + auto denominator_term = lookup_relation.template compute_write_term( + polynomials, relation_parameters, i); + denominator *= denominator_term; + }); + inverse_polynomial[i] = denominator; + }; + + // todo might be inverting zero in field bleh bleh + FF::batch_invert(inverse_polynomial); +} + +} // namespace proof_system::honk::lookup_library \ No newline at end of file diff --git a/cpp/src/barretenberg/honk/proof_system/permutation_library.hpp b/cpp/src/barretenberg/honk/proof_system/permutation_library.hpp new file mode 100644 index 0000000000..494c7d0975 --- /dev/null +++ b/cpp/src/barretenberg/honk/proof_system/permutation_library.hpp @@ -0,0 +1,165 @@ +#pragma once +#include "barretenberg/polynomials/polynomial.hpp" +#include "barretenberg/plonk/proof_system/proving_key/proving_key.hpp" +#include "barretenberg/honk/sumcheck/sumcheck.hpp" +#include + +namespace proof_system::honk::permutation_library { + +/** + * @brief Compute a permutation grand product polynomial Z_perm(X) + * * + * @details + * Z_perm may be defined in terms of its values on X_i = 0,1,...,n-1 as Z_perm[0] = 1 and for i = 1:n-1 + * relation::numerator(j) + * Z_perm[i] = ∏ -------------------------------------------------------------------------------- + * relation::denominator(j) + * + * where ∏ := ∏_{j=0:i-1} + * + * The specific algebraic relation used by Z_perm is defined by Flavor::GrandProductRelations + * + * For example, in Flavor::Standard the relation describes: + * + * (w_1(j) + β⋅id_1(j) + γ) ⋅ (w_2(j) + β⋅id_2(j) + γ) ⋅ (w_3(j) + β⋅id_3(j) + γ) + * Z_perm[i] = ∏ -------------------------------------------------------------------------------- + * (w_1(j) + β⋅σ_1(j) + γ) ⋅ (w_2(j) + β⋅σ_2(j) + γ) ⋅ (w_3(j) + β⋅σ_3(j) + γ) + * where ∏ := ∏_{j=0:i-1} and id_i(X) = id(X) + n*(i-1) + * + * For Flavor::Ultra both the UltraPermutation and Lookup grand products are computed by this method. + * + * The grand product is constructed over the course of three steps. + * + * For expositional simplicity, write Z_perm[i] as + * + * A(j) + * Z_perm[i] = ∏ -------------------------- + * B(h) + * + * Step 1) Compute 2 length-n polynomials A, B + * Step 2) Compute 2 length-n polynomials numerator = ∏ A(j), nenominator = ∏ B(j) + * Step 3) Compute Z_perm[i + 1] = numerator[i] / denominator[i] (recall: Z_perm[0] = 1) + * + * Note: Step (3) utilizes Montgomery batch inversion to replace n-many inversions with + */ +template +void compute_permutation_grand_product(const size_t circuit_size, + auto& full_polynomials, + sumcheck::RelationParameters& relation_parameters) +{ + using FF = typename Flavor::FF; + using Polynomial = typename Flavor::Polynomial; + using ValueAccumTypes = PermutationRelation::ValueAccumTypes; + + // Allocate numerator/denominator polynomials that will serve as scratch space + // TODO(zac) we can re-use the permutation polynomial as the numerator polynomial. Reduces readability + Polynomial numerator = Polynomial{ circuit_size }; + Polynomial denominator = Polynomial{ circuit_size }; + + // Step (1) + // Populate `numerator` and `denominator` with the algebra described by PermutationRelation + const size_t num_threads = circuit_size >= get_num_cpus_pow2() ? get_num_cpus_pow2() : 1; + const size_t block_size = circuit_size / num_threads; + parallel_for(num_threads, [&](size_t thread_idx) { + const size_t start = thread_idx * block_size; + const size_t end = (thread_idx + 1) * block_size; + for (size_t i = start; i < end; ++i) { + + typename Flavor::ClaimedEvaluations evaluations; + for (size_t k = 0; k < Flavor::NUM_ALL_ENTITIES; ++k) { + evaluations[k] = full_polynomials[k].size() > i ? full_polynomials[k][i] : 0; + } + numerator[i] = PermutationRelation::template compute_permutation_numerator( + evaluations, relation_parameters, i); + denominator[i] = PermutationRelation::template compute_permutation_denominator( + evaluations, relation_parameters, i); + } + }); + + // Step (2) + // Compute the accumulating product of the numerator and denominator terms. + // This step is split into three parts for efficient multithreading: + // (i) compute ∏ A(j), ∏ B(j) subproducts for each thread + // (ii) compute scaling factor required to convert each subproduct into a single running product + // (ii) combine subproducts into a single running product + // + // For example, consider 4 threads and a size-8 numerator { a0, a1, a2, a3, a4, a5, a6, a7 } + // (i) Each thread computes 1 element of N = {{ a0, a0a1 }, { a2, a2a3 }, { a4, a4a5 }, { a6, a6a7 }} + // (ii) Take partial products P = { 1, a0a1, a2a3, a4a5 } + // (iii) Each thread j computes N[i][j]*P[j]= + // {{a0,a0a1},{a0a1a2,a0a1a2a3},{a0a1a2a3a4,a0a1a2a3a4a5},{a0a1a2a3a4a5a6,a0a1a2a3a4a5a6a7}} + std::vector partial_numerators(num_threads); + std::vector partial_denominators(num_threads); + + parallel_for(num_threads, [&](size_t thread_idx) { + const size_t start = thread_idx * block_size; + const size_t end = (thread_idx + 1) * block_size; + for (size_t i = start; i < end - 1; ++i) { + numerator[i + 1] *= numerator[i]; + denominator[i + 1] *= denominator[i]; + } + partial_numerators[thread_idx] = numerator[end - 1]; + partial_denominators[thread_idx] = denominator[end - 1]; + }); + + parallel_for(num_threads, [&](size_t thread_idx) { + const size_t start = thread_idx * block_size; + const size_t end = (thread_idx + 1) * block_size; + if (thread_idx > 0) { + FF numerator_scaling = 1; + FF denominator_scaling = 1; + + for (size_t j = 0; j < thread_idx; ++j) { + numerator_scaling *= partial_numerators[j]; + denominator_scaling *= partial_denominators[j]; + } + for (size_t i = start; i < end; ++i) { + numerator[i] *= numerator_scaling; + denominator[i] *= denominator_scaling; + } + } + + // Final step: invert denominator + FF::batch_invert(std::span{ &denominator[start], block_size }); + }); + + // Step (3) Compute z_perm[i] = numerator[i] / denominator[i] + auto& grand_product_polynomial = PermutationRelation::get_grand_product_polynomial(full_polynomials); + grand_product_polynomial[0] = 0; + parallel_for(num_threads, [&](size_t thread_idx) { + const size_t start = thread_idx * block_size; + const size_t end = (thread_idx == num_threads - 1) ? circuit_size - 1 : (thread_idx + 1) * block_size; + for (size_t i = start; i < end; ++i) { + grand_product_polynomial[i + 1] = numerator[i] * denominator[i]; + } + }); +} + +template +void compute_permutation_grand_products(std::shared_ptr& key, + typename Flavor::ProverPolynomials& full_polynomials, + sumcheck::RelationParameters& relation_parameters) +{ + using GrandProductRelations = typename Flavor::GrandProductRelations; + using FF = typename Flavor::FF; + + constexpr size_t NUM_RELATIONS = std::tuple_size{}; + barretenberg::constexpr_for<0, NUM_RELATIONS, 1>([&]() { + using PermutationRelation = std::tuple_element::type; + + // Assign the grand product polynomial to the relevant std::span member of `full_polynomials` (and its shift) + // For example, for UltraPermutationRelation, this will be `full_polynomials.z_perm` + // For example, for LookupRelation, this will be `full_polynomials.z_lookup` + std::span& full_polynomial = PermutationRelation::get_grand_product_polynomial(full_polynomials); + auto& key_polynomial = PermutationRelation::get_grand_product_polynomial(*key); + full_polynomial = key_polynomial; + + compute_permutation_grand_product( + key->circuit_size, full_polynomials, relation_parameters); + std::span& full_polynomial_shift = + PermutationRelation::get_shifted_grand_product_polynomial(full_polynomials); + full_polynomial_shift = key_polynomial.shifted(); + }); +} + +} // namespace proof_system::honk::permutation_library \ No newline at end of file diff --git a/cpp/src/barretenberg/honk/proof_system/prover_library.hpp b/cpp/src/barretenberg/honk/proof_system/prover_library.hpp index 9990400c74..7db147a3e4 100644 --- a/cpp/src/barretenberg/honk/proof_system/prover_library.hpp +++ b/cpp/src/barretenberg/honk/proof_system/prover_library.hpp @@ -1,10 +1,19 @@ #pragma once +#include "barretenberg/common/constexpr_utils.hpp" #include "barretenberg/ecc/curves/bn254/fr.hpp" +#include "barretenberg/honk/sumcheck/relations/relation_parameters.hpp" #include "barretenberg/plonk/proof_system/proving_key/proving_key.hpp" #include "barretenberg/plonk/proof_system/types/program_settings.hpp" #include "barretenberg/plonk/proof_system/types/proof.hpp" -#include "barretenberg/polynomials/polynomial.hpp" +// TODO(@zac-williamson). We used to include `program_settings.hpp` in this file. Needed to remove due to circular +// dependency. `program_settings.hpp` included header files that added "using namespace proof_system" and "using +// namespace barretenberg" declarations. This effects downstream code that relies on these using declarations. This is a +// big code smell (should really not have using declarations in header files!), however fixing it requires changes in a +// LOT of files. This would clutter the eccvm feature PR. Adding these following "using namespace" declarations is a +// temp workaround. Once this work is merged in we should fix the root problem (no using declarations in header files) +using namespace proof_system; +using namespace barretenberg; namespace proof_system::honk::prover_library { template diff --git a/cpp/src/barretenberg/honk/sumcheck/polynomials/univariate.hpp b/cpp/src/barretenberg/honk/sumcheck/polynomials/univariate.hpp index 948382b845..5765f74cb5 100644 --- a/cpp/src/barretenberg/honk/sumcheck/polynomials/univariate.hpp +++ b/cpp/src/barretenberg/honk/sumcheck/polynomials/univariate.hpp @@ -96,6 +96,15 @@ template class Univariate { res -= other; return res; } + Univariate operator-() const + { + Univariate res(*this); + for (auto& eval : res.evaluations) { + eval = -eval; + } + return res; + } + Univariate operator*(const Univariate& other) const { Univariate res(*this); @@ -249,6 +258,15 @@ template class UnivariateView { return res; } + Univariate operator-() const + { + Univariate res(*this); + for (auto& eval : res.evaluations) { + eval = -eval; + } + return res; + } + Univariate operator*(const UnivariateView& other) const { Univariate res(*this); diff --git a/cpp/src/barretenberg/honk/sumcheck/relations/ecc_vm/ecc_vm_relation.test.cpp b/cpp/src/barretenberg/honk/sumcheck/relations/ecc_vm/ecc_vm_relation.test.cpp new file mode 100644 index 0000000000..1f0de00cca --- /dev/null +++ b/cpp/src/barretenberg/honk/sumcheck/relations/ecc_vm/ecc_vm_relation.test.cpp @@ -0,0 +1,362 @@ +#include "barretenberg/honk/composer/composer/eccvm_composer.hpp" +#include "barretenberg/honk/flavor/ecc_vm.hpp" +#include "barretenberg/honk/proof_system/lookup_library.hpp" +#include "barretenberg/honk/proof_system/permutation_library.hpp" +#include "barretenberg/honk/proof_system/prover_library.hpp" +#include "barretenberg/honk/sumcheck/sumcheck.hpp" +#include "barretenberg/numeric/random/engine.hpp" +#include "barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.hpp" +#include + +/** + * We want to test if all three relations (namely, ArithmeticRelation, GrandProductComputationRelation, + * GrandProductInitializationRelation) provide correct contributions by manually computing their + * contributions with deterministic and random inputs. The relations are supposed to work with + * univariates (edges) of degree one (length 2) and spit out polynomials of corresponding degrees. We have + * MAX_RELATION_LENGTH = 5, meaning the output of a relation can atmost be a degree 5 polynomial. Hence, + * we use a method compute_mock_extended_edges() which starts with degree one input polynomial (two evaluation + points), + * extends them (using barycentric formula) to six evaluation points, and stores them to an array of polynomials. + */ + +using namespace proof_system::honk::sumcheck; +using Flavor = proof_system::honk::flavor::ECCVM; +using FF = typename Flavor::FF; +using ProverPolynomials = typename Flavor::ProverPolynomials; +using RawPolynomials = typename Flavor::RawPolynomials; + +static constexpr size_t NUM_POLYNOMIALS = Flavor::NUM_ALL_ENTITIES; + +namespace proof_system::honk_relation_tests_ecc_vm_full { + +namespace { +auto& engine = numeric::random::get_debug_engine(); +} + +static grumpkin::g1::element a; +static grumpkin::g1::element b; +static grumpkin::g1::element c; +static grumpkin::fr x; +static grumpkin::fr y; +static bool init = false; + +ECCVMCircuitConstructor generate_trace(numeric::random::Engine* engine = nullptr) +{ + ECCVMCircuitConstructor result; + if (!init) { + a = grumpkin::get_generator(0); + b = grumpkin::get_generator(1); + c = grumpkin::get_generator(2); + x = grumpkin::fr::random_element(engine); + y = grumpkin::fr::random_element(engine); + init = true; + } + grumpkin::g1::element expected_1 = (a * x) + a + (b * x) + (b * x) + (b * x); + grumpkin::g1::element expected_2 = (a * x) + c + (b * x); + + result.mul_accumulate(a, x); + + return result; +} + +TEST(SumcheckRelation, ECCVMLookupRelationAlgebra) +{ + const auto run_test = []() { + auto lookup_relation = ECCVMLookupRelation(); + + barretenberg::fr scaling_factor = barretenberg::fr::random_element(); + const FF gamma = FF::random_element(&engine); + const FF eta = FF::random_element(&engine); + const FF eta_sqr = eta.sqr(); + const FF eta_cube = eta_sqr * eta; + auto permutation_offset = + gamma * (gamma + eta_sqr) * (gamma + eta_sqr + eta_sqr) * (gamma + eta_sqr + eta_sqr + eta_sqr); + permutation_offset = permutation_offset.invert(); + honk::sumcheck::RelationParameters relation_params{ + .eta = eta, + .beta = 1, + .gamma = gamma, + .public_input_delta = 1, + .lookup_grand_product_delta = 1, + .eta_sqr = eta_sqr, + .eta_cube = eta_cube, + .permutation_offset = permutation_offset, + }; + + auto circuit_constructor = generate_trace(&engine); + auto rows = circuit_constructor.compute_full_polynomials(); + const size_t num_rows = rows[0].size(); + honk::lookup_library::compute_logderivative_inverse>( + rows, relation_params, num_rows); + honk::permutation_library::compute_permutation_grand_product>( + num_rows, rows, relation_params); + rows.z_perm_shift = Flavor::Polynomial(rows.z_perm.shifted()); + + // auto transcript_trace = transcript_trace.export_rows(); + + ECCVMLookupRelation::RelationValues result; + for (auto& r : result) { + r = 0; + } + for (size_t i = 0; i < num_rows; ++i) { + Flavor::RowPolynomials row; + for (size_t j = 0; j < NUM_POLYNOMIALS; ++j) { + row[j] = rows[j][i]; + } + lookup_relation.add_full_relation_value_contribution(result, row, relation_params, scaling_factor); + } + + for (auto r : result) { + EXPECT_EQ(r, 0); + } + }; + run_test(); +} + +TEST(SumcheckRelation, ECCVMFullRelationAlgebra) +{ + const auto run_test = []() { + // auto transcript_relation = ECCVMTranscriptRelation(); + // auto point_relation = ECCVMPointTableRelation(); + // auto wnaf_relation = ECCVMWnafRelation(); + // auto msm_relation = ECCVMMSMRelation(); + // auto set_relation = ECCVMSetRelation(); + auto lookup_relation = ECCVMLookupRelation(); + + barretenberg::fr scaling_factor = barretenberg::fr::random_element(); + const FF gamma = FF::random_element(&engine); + const FF eta = FF::random_element(&engine); + const FF eta_sqr = eta.sqr(); + const FF eta_cube = eta_sqr * eta; + auto permutation_offset = + gamma * (gamma + eta_sqr) * (gamma + eta_sqr + eta_sqr) * (gamma + eta_sqr + eta_sqr + eta_sqr); + permutation_offset = permutation_offset.invert(); + honk::sumcheck::RelationParameters relation_params{ + .eta = eta, + .beta = 1, + .gamma = gamma, + .public_input_delta = 1, + .lookup_grand_product_delta = 1, + .eta_sqr = eta_sqr, + .eta_cube = eta_cube, + .permutation_offset = permutation_offset, + }; + auto circuit_constructor = generate_trace(&engine); + auto rows = circuit_constructor.compute_full_polynomials(); + const size_t num_rows = rows[0].size(); + honk::lookup_library::compute_logderivative_inverse>( + rows, relation_params, num_rows); + honk::permutation_library::compute_permutation_grand_product>( + num_rows, rows, relation_params); + rows.z_perm_shift = Flavor::Polynomial(rows.z_perm.shifted()); + + // compute_permutation_polynomials(rows, relation_params); + // compute_lookup_inverse_polynomial(rows, relation_params); + + // auto transcript_trace = transcript_trace.export_rows(); + + ECCVMLookupRelation::RelationValues lookup_result; + for (auto& r : lookup_result) { + r = 0; + } + + const auto evaluate_relation = [&](std::string relation_name) { + auto relation = Relation(); + typename Relation::RelationValues result; + for (auto& r : result) { + r = 0; + } + constexpr size_t NUM_SUBRELATIONS = result.size(); + std::array relation_fail{}; + std::array relation_fails_at_row{}; + + for (size_t i = 0; i < num_rows; ++i) { + Flavor::RowPolynomials row; + for (size_t j = 0; j < NUM_POLYNOMIALS; ++j) { + row[j] = rows[j][i]; + } + relation.add_full_relation_value_contribution(result, row, relation_params, scaling_factor); + + for (size_t j = 0; j < NUM_SUBRELATIONS; ++j) { + if (result[j] != 0) { + if (!relation_fail[j]) { + relation_fail[j] = true; + relation_fails_at_row[j] = i; + } + } + } + } + + for (size_t j = 0; j < NUM_SUBRELATIONS; ++j) { + EXPECT_EQ(relation_fail[j], false); + if (relation_fail[j]) { + std::cerr << "relation " << relation_name << ", subrelation " << j + << " fails. First failure at row " << relation_fails_at_row[j] << std::endl; + } + } + }; + + evaluate_relation.template operator()>("ECCVMTranscriptRelation"); + evaluate_relation.template operator()>("ECCVMPointTableRelation"); + evaluate_relation.template operator()>("ECCVMWnafRelation"); + evaluate_relation.template operator()>("ECCVMMSMRelation"); + evaluate_relation.template operator()>("ECCVMSetRelation"); + + for (size_t i = 0; i < num_rows; ++i) { + Flavor::RowPolynomials row; + for (size_t j = 0; j < NUM_POLYNOMIALS; ++j) { + row[j] = rows[j][i]; + } + { + lookup_relation.add_full_relation_value_contribution( + lookup_result, row, relation_params, scaling_factor); + } + } + for (auto r : lookup_result) { + EXPECT_EQ(r, 0); + } + }; + run_test(); +} + +TEST(SumcheckRelation, ECCVMFullRelationProver) +{ + const auto run_test = []() { + const FF gamma = FF::random_element(&engine); + const FF eta = FF::random_element(&engine); + const FF eta_sqr = eta.sqr(); + const FF eta_cube = eta_sqr * eta; + auto permutation_offset = + gamma * (gamma + eta_sqr) * (gamma + eta_sqr + eta_sqr) * (gamma + eta_sqr + eta_sqr + eta_sqr); + permutation_offset = permutation_offset.invert(); + honk::sumcheck::RelationParameters relation_params{ + .eta = eta, + .beta = 1, + .gamma = gamma, + .public_input_delta = 1, + .lookup_grand_product_delta = 1, + .eta_sqr = eta_sqr, + .eta_cube = eta_cube, + .permutation_offset = permutation_offset, + }; + + auto circuit_constructor = generate_trace(&engine); + auto full_polynomials = circuit_constructor.compute_full_polynomials(); + const size_t num_rows = full_polynomials[0].size(); + honk::lookup_library::compute_logderivative_inverse>( + full_polynomials, relation_params, num_rows); + + honk::permutation_library::compute_permutation_grand_product>( + num_rows, full_polynomials, relation_params); + full_polynomials.z_perm_shift = Flavor::Polynomial(full_polynomials.z_perm.shifted()); + + // size_t pidx = 0; + // for (auto& p : full_polynomials) { + // size_t count = 0; + // for (auto& x : p) { + // std::cout << "poly[" << pidx << "][" << count << "] = " << x << std::endl; + // count++; + // } + // pidx++; + // } + // auto foo = full_polynomials.get_to_be_shifted(); + // size_t c = 0; + // for (auto& x : foo) { + // if (x[0] != 0) { + // std::cout << "shift at " << c << "not zero :/" << std::endl; + // } + // c += 1; + // } + const size_t multivariate_n = full_polynomials[0].size(); + const size_t multivariate_d = numeric::get_msb64(multivariate_n); + + EXPECT_EQ(1ULL << multivariate_d, multivariate_n); + + auto prover_transcript = honk::ProverTranscript::init_empty(); + + auto sumcheck_prover = Sumcheck>(multivariate_n, prover_transcript); + + auto prover_output = sumcheck_prover.execute_prover(full_polynomials, relation_params); + + auto verifier_transcript = honk::VerifierTranscript::init_empty(prover_transcript); + + auto sumcheck_verifier = Sumcheck>(multivariate_n, verifier_transcript); + + std::optional verifier_output = sumcheck_verifier.execute_verifier(relation_params); + + ASSERT_TRUE(verifier_output.has_value()); + ASSERT_EQ(prover_output, *verifier_output); + }; + run_test(); +} + +class ECCVMComposerTestsB : public ::testing::Test { + protected: + static void SetUpTestSuite() { barretenberg::srs::init_crs_factory("../srs_db/ignition"); } +}; +TEST_F(ECCVMComposerTestsB, BaseCase) +{ + auto circuit_constructor = generate_trace(&engine); + // auto composer = honk::ECCVMComposerHelper(); + // auto prover = composer.create_prover(circuit_constructor); + + // prover.construct_proof(); + // auto eta = prover.relation_parameters.eta; + // auto beta = prover.relation_parameters.beta; + // auto gamma = prover.relation_parameters.gamma; + // ECCVMBuilder trace2 = generate_trace(&engine); + + auto eta = FF::random_element(&engine); // prover.relation_parameters.eta; + auto beta = FF::random_element(&engine); // prover.relation_parameters.beta; + auto gamma = FF::random_element(&engine); // prover.relation_parameters.gamma; + const FF eta_sqr = eta.sqr(); + const FF eta_cube = eta_sqr * eta; + auto permutation_offset = + gamma * (gamma + eta_sqr) * (gamma + eta_sqr + eta_sqr) * (gamma + eta_sqr + eta_sqr + eta_sqr); + permutation_offset = permutation_offset.invert(); + + honk::sumcheck::RelationParameters relation_params{ + .eta = eta, + .beta = beta, + .gamma = gamma, + .public_input_delta = 0, + .lookup_grand_product_delta = 0, + .eta_sqr = eta_sqr, + .eta_cube = eta_cube, + .permutation_offset = permutation_offset, + }; + // std::cout << "gamma eta = " << gamma << " , " << eta << std::endl; + + // RawPolynomials full_polynomials = trace2.compute_full_polynomials(); + + // auto& full_polynomials = prover.prover_polynomials; + auto full_polynomials = circuit_constructor.compute_full_polynomials(); + // compute_logderivative_inverse(prover.proving_key, full_polynomials) + const size_t multivariate_n = full_polynomials[0].size(); + const size_t multivariate_d = numeric::get_msb64(multivariate_n); + + EXPECT_EQ(1ULL << multivariate_d, multivariate_n); + + honk::lookup_library::compute_logderivative_inverse>( + full_polynomials, relation_params, multivariate_n); + + honk::permutation_library::compute_permutation_grand_product>( + multivariate_n, full_polynomials, relation_params); + full_polynomials.z_perm_shift = Flavor::Polynomial(full_polynomials.z_perm.shifted()); + + auto prover_transcript = honk::ProverTranscript::init_empty(); + + auto sumcheck_prover = Sumcheck>(multivariate_n, prover_transcript); + + auto prover_output = sumcheck_prover.execute_prover(full_polynomials, relation_params); + + auto verifier_transcript = honk::VerifierTranscript::init_empty(prover_transcript); + + auto sumcheck_verifier = Sumcheck>(multivariate_n, verifier_transcript); + + std::optional verifier_output = sumcheck_verifier.execute_verifier(relation_params); + + ASSERT_TRUE(verifier_output.has_value()); + ASSERT_EQ(prover_output, *verifier_output); +} +} // namespace proof_system::honk_relation_tests_ecc_vm_full diff --git a/cpp/src/barretenberg/honk/sumcheck/relations/relation_definitions_fwd.hpp b/cpp/src/barretenberg/honk/sumcheck/relations/relation_definitions_fwd.hpp new file mode 100644 index 0000000000..53709d5e9b --- /dev/null +++ b/cpp/src/barretenberg/honk/sumcheck/relations/relation_definitions_fwd.hpp @@ -0,0 +1,45 @@ +#pragma once + +#include "relation_types.hpp" + +#define ExtendedEdge(Flavor) Flavor::ExtendedEdges +#define EvaluationEdge(Flavor) Flavor::ClaimedEvaluations +#define EntityEdge(Flavor) Flavor::AllEntities + +#define ADD_EDGE_CONTRIBUTION(...) _ADD_EDGE_CONTRIBUTION(__VA_ARGS__) +#define _ADD_EDGE_CONTRIBUTION(Preface, Relation, Flavor, AccumulatorType, EdgeType) \ + Preface template void \ + Relation::add_edge_contribution_impl::AccumulatorType, \ + EdgeType(Flavor)>( \ + RelationWrapper::AccumulatorType::Accumulators&, \ + EdgeType(Flavor) const&, \ + RelationParameters const&, \ + Flavor::FF const&) const; + +#define PERMUTATION_METHOD(...) _PERMUTATION_METHOD(__VA_ARGS__) +#define _PERMUTATION_METHOD(Preface, MethodName, Relation, Flavor, AccumulatorType, EdgeType) \ + Preface template Relation::template Accumulator< \ + RelationWrapper::AccumulatorType> \ + Relation::MethodName::AccumulatorType, EdgeType(Flavor)>( \ + EdgeType(Flavor) const&, RelationParameters const&, size_t const); + +#define SUMCHECK_RELATION_CLASS(...) _SUMCHECK_RELATION_CLASS(__VA_ARGS__) +#define _SUMCHECK_RELATION_CLASS(Preface, Relation, Flavor) \ + ADD_EDGE_CONTRIBUTION(Preface, Relation, Flavor, UnivariateAccumTypes, ExtendedEdge) \ + ADD_EDGE_CONTRIBUTION(Preface, Relation, Flavor, ValueAccumTypes, EvaluationEdge) \ + ADD_EDGE_CONTRIBUTION(Preface, Relation, Flavor, ValueAccumTypes, EntityEdge) + +#define DECLARE_SUMCHECK_RELATION_CLASS(Relation, Flavor) SUMCHECK_RELATION_CLASS(extern, Relation, Flavor) +#define DEFINE_SUMCHECK_RELATION_CLASS(Relation, Flavor) SUMCHECK_RELATION_CLASS(, Relation, Flavor) + +#define SUMCHECK_PERMUTATION_CLASS(...) _SUMCHECK_PERMUTATION_CLASS(__VA_ARGS__) +#define _SUMCHECK_PERMUTATION_CLASS(Preface, Relation, Flavor) \ + PERMUTATION_METHOD(Preface, compute_permutation_numerator, Relation, Flavor, UnivariateAccumTypes, ExtendedEdge) \ + PERMUTATION_METHOD(Preface, compute_permutation_numerator, Relation, Flavor, ValueAccumTypes, EvaluationEdge) \ + PERMUTATION_METHOD(Preface, compute_permutation_numerator, Relation, Flavor, ValueAccumTypes, EntityEdge) \ + PERMUTATION_METHOD(Preface, compute_permutation_denominator, Relation, Flavor, UnivariateAccumTypes, ExtendedEdge) \ + PERMUTATION_METHOD(Preface, compute_permutation_denominator, Relation, Flavor, ValueAccumTypes, EvaluationEdge) \ + PERMUTATION_METHOD(Preface, compute_permutation_denominator, Relation, Flavor, ValueAccumTypes, EntityEdge) + +#define DECLARE_SUMCHECK_PERMUTATION_CLASS(Relation, Flavor) SUMCHECK_PERMUTATION_CLASS(extern, Relation, Flavor) +#define DEFINE_SUMCHECK_PERMUTATION_CLASS(Relation, Flavor) SUMCHECK_PERMUTATION_CLASS(, Relation, Flavor) diff --git a/cpp/src/barretenberg/honk/sumcheck/relations/relation_parameters.hpp b/cpp/src/barretenberg/honk/sumcheck/relations/relation_parameters.hpp index 863688560c..ccc74b3391 100644 --- a/cpp/src/barretenberg/honk/sumcheck/relations/relation_parameters.hpp +++ b/cpp/src/barretenberg/honk/sumcheck/relations/relation_parameters.hpp @@ -14,5 +14,8 @@ template struct RelationParameters { FF gamma = FF::zero(); // Permutation + Lookup FF public_input_delta = FF::zero(); // Permutation FF lookup_grand_product_delta = FF::zero(); // Lookup + FF eta_sqr = FF::zero(); + FF eta_cube = FF::zero(); + FF permutation_offset = FF::zero(); // TODO(@zac-williamson) explain what this is (to do w. set equality check) }; } // namespace proof_system::honk::sumcheck diff --git a/cpp/src/barretenberg/honk/sumcheck/relations/relation_types.hpp b/cpp/src/barretenberg/honk/sumcheck/relations/relation_types.hpp index a4842eef7c..b370f8f835 100644 --- a/cpp/src/barretenberg/honk/sumcheck/relations/relation_types.hpp +++ b/cpp/src/barretenberg/honk/sumcheck/relations/relation_types.hpp @@ -1,13 +1,25 @@ #pragma once #include #include - -#include "../polynomials/univariate.hpp" +#include #include "relation_parameters.hpp" +// forward-declare Polynomial so we can use in a concept +namespace barretenberg { +template class Polynomial; +} namespace proof_system::honk::sumcheck { -template -concept HasSubrelationLinearlyIndependentMember = requires(T) { T::Relation::SUBRELATION_LINEARLY_INDEPENDENT; }; + +template class Univariate; +template class UnivariateView; + +template +concept HasSubrelationLinearlyIndependentMember = requires(T) { + { + std::get(T::SUBRELATION_LINEARLY_INDEPENDENT) + } -> std::convertible_to; +}; + /** * @brief The templates defined herein facilitate sharing the relation arithmetic between the prover and the verifier. * @@ -27,7 +39,7 @@ concept HasSubrelationLinearlyIndependentMember = requires(T) { T::Relation::SUB * @brief Getter method that will return `input[index]` iff `input` is a std::span container * * @tparam FF - * @tparam TypeMuncher + * @tparam AccumulatorTypes * @tparam T * @param input * @param index @@ -42,7 +54,25 @@ inline typename std::tuple_element<0, typename AccumulatorTypes::AccumulatorView } /** - * @brief Getter method that will return `input[index]` iff `input` is not a std::span container + * @brief Getter method that will return `input[index]` iff `input` is a Polynomial container + * + * @tparam FF + * @tparam TypeMuncher + * @tparam T + * @param input + * @param index + * @return requires + */ +template + requires std::is_same, T>::value +inline std::tuple_element<0, typename AccumulatorTypes::AccumulatorViews>::type get_view(const T& input, + const size_t index) +{ + return input[index]; +} + +/** + * @brief Getter method that will return `input[index]` iff `input` is not a std::span or a Polynomial container * * @tparam FF * @tparam TypeMuncher @@ -58,6 +88,7 @@ inline typename std::tuple_element<0, typename AccumulatorTypes::AccumulatorView return typename std::tuple_element<0, typename AccumulatorTypes::AccumulatorViews>::type(input); } + /** * @brief A wrapper for Relations to expose methods used by the Sumcheck prover or verifier to add the contribution of * a given relation to the corresponding accumulator. @@ -102,30 +133,29 @@ template typename RelationBase> class Relation Relation::template add_edge_contribution_impl( accumulator, input, relation_parameters, scaling_factor); } - /** * @brief Check is subrelation is linearly independent - * Method always returns true if relation has no SUBRELATION_LINEARLY_INDEPENDENT std::array - * (i.e. default is to make linearly independent) + * Method is active if relation has SUBRELATION_LINEARLY_INDEPENDENT array defined * @tparam size_t */ - template + template static constexpr bool is_subrelation_linearly_independent() - requires(!HasSubrelationLinearlyIndependentMember) + requires(HasSubrelationLinearlyIndependentMember) { - return true; + return std::get(Relation::SUBRELATION_LINEARLY_INDEPENDENT); } /** * @brief Check is subrelation is linearly independent - * Method is active if relation has SUBRELATION_LINEARLY_INDEPENDENT array defined + * Method always returns true if relation has no SUBRELATION_LINEARLY_INDEPENDENT std::array + * (i.e. default is to make linearly independent) * @tparam size_t */ template static constexpr bool is_subrelation_linearly_independent() - requires(HasSubrelationLinearlyIndependentMember) + requires(!HasSubrelationLinearlyIndependentMember) { - return std::get(Relation::SUBRELATION_LINEARLY_INDEPENDENT); + return true; } }; diff --git a/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_builder_types.hpp b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_builder_types.hpp new file mode 100644 index 0000000000..978d0bed55 --- /dev/null +++ b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_builder_types.hpp @@ -0,0 +1,36 @@ +#pragma once + +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" + +namespace proof_system_eccvm { + +static constexpr size_t NUM_SCALAR_BITS = 128; +static constexpr size_t WNAF_SLICE_BITS = 4; +static constexpr size_t NUM_WNAF_SLICES = (NUM_SCALAR_BITS + WNAF_SLICE_BITS - 1) / WNAF_SLICE_BITS; +static constexpr uint64_t WNAF_MASK = static_cast((1ULL << WNAF_SLICE_BITS) - 1ULL); +static constexpr size_t POINT_TABLE_SIZE = 1ULL << (WNAF_SLICE_BITS); +static constexpr size_t WNAF_SLICES_PER_ROW = 4; +static constexpr size_t ADDITIONS_PER_ROW = 4; + +template struct VMOperation { + bool add = false; + bool mul = false; + bool eq = false; + bool reset = false; + typename CycleGroup::affine_element base_point = typename CycleGroup::affine_element{ 0, 0 }; + uint256_t z1 = 0; + uint256_t z2 = 0; + typename CycleGroup::subgroup_field mul_scalar_full = 0; +}; +template struct ScalarMul { + uint32_t pc; + uint256_t scalar; + typename CycleGroup::affine_element base_point; + std::array wnaf_slices; + bool wnaf_skew; + std::array precomputed_table; +}; + +template using MSM = std::vector>; + +} // namespace proof_system_eccvm \ No newline at end of file diff --git a/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.hpp b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.hpp new file mode 100644 index 0000000000..245c489ea8 --- /dev/null +++ b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.hpp @@ -0,0 +1,489 @@ +#pragma once + +#include "./eccvm_builder_types.hpp" +#include "barretenberg/ecc/curves/bn254/fr.hpp" +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" + +#include "./msm_builder.hpp" +#include "./transcript_builder.hpp" +#include "./precomputed_tables_builder.hpp" +#include "barretenberg/honk/flavor/ecc_vm.hpp" +#include "barretenberg/honk/proof_system/lookup_library.hpp" +#include "barretenberg/honk/proof_system/permutation_library.hpp" +#include "barretenberg/honk/sumcheck/relations/relation_parameters.hpp" + +namespace proof_system { + +template class ECCVMCircuitConstructor { + public: + using CycleGroup = typename Flavor::CycleGroup; + using CycleScalar = typename CycleGroup::subgroup_field; + using FF = typename Flavor::FF; + using Element = typename CycleGroup::element; + using AffineElement = typename CycleGroup::affine_element; + static constexpr size_t NUM_SCALAR_BITS = proof_system_eccvm::NUM_SCALAR_BITS; + static constexpr size_t WNAF_SLICE_BITS = proof_system_eccvm::WNAF_SLICE_BITS; + static constexpr size_t NUM_WNAF_SLICES = proof_system_eccvm::NUM_WNAF_SLICES; + static constexpr uint64_t WNAF_MASK = proof_system_eccvm::WNAF_MASK; + static constexpr size_t POINT_TABLE_SIZE = proof_system_eccvm::POINT_TABLE_SIZE; + static constexpr size_t WNAF_SLICES_PER_ROW = proof_system_eccvm::WNAF_SLICES_PER_ROW; + static constexpr size_t ADDITIONS_PER_ROW = proof_system_eccvm::ADDITIONS_PER_ROW; + + static constexpr size_t NUM_POLYNOMIALS = Flavor::NUM_ALL_ENTITIES; + static constexpr size_t NUM_WIRES = Flavor::NUM_WIRES; + + using MSM = proof_system_eccvm::MSM; + using VMOperation = proof_system_eccvm::VMOperation; + std::vector vm_operations; + using ScalarMul = proof_system_eccvm::ScalarMul; + using RawPolynomials = typename Flavor::RawPolynomials; + using Polynomial = barretenberg::Polynomial; + uint32_t get_number_of_muls() + { + uint32_t num_muls = 0; + for (auto& op : vm_operations) { + if (op.mul) { + if (op.z1 != 0) { + num_muls++; + } + if (op.z2 != 0) { + num_muls++; + } + } + } + return num_muls; + } + + std::vector get_msms() + { + const uint32_t num_muls = get_number_of_muls(); + /** + * For input point [P], return { -15[P], -13[P], ..., -[P], [P], ..., 13[P], 15[P] } + */ + const auto compute_precomputed_table = [](const AffineElement& base_point) { + const auto d2 = Element(base_point).dbl(); + std::array table; + table[POINT_TABLE_SIZE / 2] = base_point; + for (size_t i = 1; i < POINT_TABLE_SIZE / 2; ++i) { + table[i + POINT_TABLE_SIZE / 2] = Element(table[i + POINT_TABLE_SIZE / 2 - 1]) + d2; + } + for (size_t i = 0; i < POINT_TABLE_SIZE / 2; ++i) { + table[i] = -table[POINT_TABLE_SIZE - 1 - i]; + } + return table; + }; + const auto compute_wnaf_slices = [](uint256_t scalar) { + std::array output; + int previous_slice = 0; + for (size_t i = 0; i < NUM_WNAF_SLICES; ++i) { + // slice the scalar into 4-bit chunks, starting with the least significant bits + uint64_t raw_slice = static_cast(scalar) & WNAF_MASK; + + bool is_even = ((raw_slice & 1ULL) == 0ULL); + + int wnaf_slice = static_cast(raw_slice); + + if (i == 0 && is_even) { + // if least significant slice is even, we add 1 to create an odd value && set 'skew' to true + wnaf_slice += 1; + } else if (is_even) { + // for other slices, if it's even, we add 1 to the slice value + // and subtract 16 from the previous slice to preserve the total scalar sum + static constexpr int borrow_constant = static_cast(1ULL << WNAF_SLICE_BITS); + previous_slice -= borrow_constant; + wnaf_slice += 1; + } + + if (i > 0) { + const size_t idx = i - 1; + output[NUM_WNAF_SLICES - idx - 1] = previous_slice; + } + previous_slice = wnaf_slice; + + // downshift raw_slice by 4 bits + scalar = scalar >> WNAF_SLICE_BITS; + } + + ASSERT(scalar == 0); + + output[0] = previous_slice; + + return output; + }; + std::vector msms; + std::vector active_msm; + + // We start pc at `num_muls` and decrement for each mul processed. + // This gives us two desired properties: + // 1: the value of pc at the 1st row = number of muls (easy to check) + // 2: the value of pc for the final mul = 1 + // The latter point is valuable as it means that we can add empty rows (where pc = 0) and still satisfy our + // sumcheck relations that involve pc (if we did the other way around, starting at 1 and ending at num_muls, + // we create a discontinuity in pc values between the last transcript row and the following empty row) + uint32_t pc = num_muls; + + const auto process_mul = [&active_msm, &pc, &compute_wnaf_slices, &compute_precomputed_table]( + const auto& scalar, const auto& base_point) { + if (scalar != 0) { + active_msm.push_back(ScalarMul{ + .pc = pc, + .scalar = scalar, + .base_point = base_point, + .wnaf_slices = compute_wnaf_slices(scalar), + .wnaf_skew = (scalar & 1) == 0, + .precomputed_table = compute_precomputed_table(base_point), + }); + pc--; + } + }; + + for (auto& op : vm_operations) { + if (op.mul) { + process_mul(op.z1, op.base_point); + process_mul(op.z2, AffineElement{ op.base_point.x * FF::cube_root_of_unity(), -op.base_point.y }); + + } else { + if (!active_msm.empty()) { + msms.push_back(active_msm); + active_msm = {}; + } + } + } + if (!active_msm.empty()) { + msms.push_back(active_msm); + } + + ASSERT(pc == 0); + return msms; + } + + static std::vector get_flattened_scalar_muls(const std::vector& msms) + { + std::vector result; + for (const auto& msm : msms) { + for (const auto& mul : msm) { + result.push_back(mul); + } + } + return result; + } + + void add_accumulate(const AffineElement& to_add) + { + vm_operations.emplace_back(VMOperation{ + .add = true, + .mul = false, + .eq = false, + .reset = false, + .base_point = to_add, + .z1 = 0, + .z2 = 0, + .mul_scalar_full = 0, + }); + } + + void mul_accumulate(const AffineElement& to_mul, const CycleScalar& scalar) + { + CycleScalar z1 = 0; + CycleScalar z2 = 0; + auto converted = scalar.from_montgomery_form(); + CycleScalar::split_into_endomorphism_scalars(converted, z1, z2); + z1 = z1.to_montgomery_form(); + z2 = z2.to_montgomery_form(); + vm_operations.emplace_back(VMOperation{ + .add = false, + .mul = true, + .eq = false, + .reset = false, + .base_point = to_mul, + .z1 = z1, + .z2 = z2, + .mul_scalar_full = scalar, + }); + } + void eq(const AffineElement& expected) + { + vm_operations.emplace_back(VMOperation{ + .add = false, + .mul = false, + .eq = true, + .reset = true, + .base_point = expected, + .z1 = 0, + .z2 = 0, + .mul_scalar_full = 0, + }); + } + + void empty_row() + { + vm_operations.emplace_back(VMOperation{ + .add = false, + .mul = false, + .eq = false, + .reset = false, + .base_point = CycleGroup::affine_point_at_infinity, + .z1 = 0, + .z2 = 0, + .mul_scalar_full = 0, + }); + } + + RawPolynomials compute_full_polynomials() + { + const auto msms = get_msms(); + const auto flattened_muls = get_flattened_scalar_muls(msms); + + std::array, 2> point_table_read_counts; + const auto transcript_state = + ECCVMTranscriptBuilder::compute_transcript_state(vm_operations, get_number_of_muls()); + const auto precompute_table_state = + ECCVMPrecomputedTablesBuilder::compute_precompute_state(flattened_muls); + const auto msm_state = + ECCVMMSMMBuilder::compute_msm_state(msms, point_table_read_counts, get_number_of_muls()); + + const size_t msm_size = msm_state.size(); + const size_t transcript_size = transcript_state.size(); + const size_t precompute_table_size = precompute_table_state.size(); + + const size_t num_rows = std::max(precompute_table_size, std::max(msm_size, transcript_size)); + + const size_t num_rows_log2 = numeric::get_msb64(num_rows); + size_t num_rows_pow2 = 1UL << (num_rows_log2 + (1UL << num_rows_log2 == num_rows ? 0 : 1)); + + RawPolynomials rows; + for (size_t j = 0; j < NUM_POLYNOMIALS; ++j) { + rows[j] = Polynomial(num_rows_pow2); + } + + rows.lagrange_first[0] = 1; + rows.lagrange_second[1] = 1; + rows.lagrange_last[rows.lagrange_last.size() - 1] = 1; + + for (size_t i = 0; i < point_table_read_counts[0].size(); ++i) { + // TODO(@zac-williamson) explain off-by-one offset + // When computing the WNAF slice for a point at point counter value `pc` and a round index `round`, the row + // number that computes the slice can be derived. This row number is then mapped to the index of + // `lookup_read_counts`. We do this mapping in `ecc_msm_relation`. We are off-by-one because we add an empty + // row at the start of the WNAF columns that is not accounted for (index of lookup_read_counts maps to the + // row in our WNAF columns that computes a slice for a given value of pc and round) + rows.lookup_read_counts_0[i + 1] = point_table_read_counts[0][i]; + rows.lookup_read_counts_1[i + 1] = point_table_read_counts[1][i]; + } + for (size_t i = 0; i < transcript_state.size(); ++i) { + rows.transcript_accumulator_empty[i] = transcript_state[i].accumulator_empty; + rows.q_transcript_add[i] = transcript_state[i].q_add; + rows.q_transcript_mul[i] = transcript_state[i].q_mul; + rows.q_transcript_eq[i] = transcript_state[i].q_eq; + rows.transcript_q_reset_accumulator[i] = transcript_state[i].q_reset_accumulator; + rows.q_transcript_msm_transition[i] = transcript_state[i].q_msm_transition; + rows.transcript_pc[i] = transcript_state[i].pc; + rows.transcript_msm_count[i] = transcript_state[i].msm_count; + rows.transcript_x[i] = transcript_state[i].base_x; + rows.transcript_y[i] = transcript_state[i].base_y; + rows.transcript_z1[i] = transcript_state[i].z1; + rows.transcript_z2[i] = transcript_state[i].z2; + rows.transcript_z1zero[i] = transcript_state[i].z1_zero; + rows.transcript_z2zero[i] = transcript_state[i].z2_zero; + rows.transcript_op[i] = transcript_state[i].opcode; + rows.transcript_accumulator_x[i] = transcript_state[i].accumulator_x; + rows.transcript_accumulator_y[i] = transcript_state[i].accumulator_y; + rows.transcript_msm_x[i] = transcript_state[i].msm_output_x; + rows.transcript_msm_y[i] = transcript_state[i].msm_output_y; + } + + // TODO(@zac-williamson) if final opcode resets accumulator, all subsequent "is_accumulator_empty" row values + // must be 1. Ideally we find a way to tweak this so that empty rows that do nothing have column values that are + // all zero + if (transcript_state[transcript_state.size() - 1].accumulator_empty == 1) { + for (size_t i = transcript_state.size(); i < num_rows_pow2; ++i) { + rows.transcript_accumulator_empty[i] = 1; + } + } + for (size_t i = 0; i < precompute_table_state.size(); ++i) { + rows.q_wnaf[i] = (i != 0) ? 1 : 0; // todo document, derive etc etc // first row is empty! + rows.table_pc[i] = precompute_table_state[i].pc; + rows.table_point_transition[i] = static_cast(precompute_table_state[i].point_transition); + // rows.table_point_transition_shift = static_cast(table_state[i].point_transition); + rows.table_round[i] = precompute_table_state[i].round; + rows.table_scalar_sum[i] = precompute_table_state[i].scalar_sum; + + rows.table_s1[i] = precompute_table_state[i].s1; + rows.table_s2[i] = precompute_table_state[i].s2; + rows.table_s3[i] = precompute_table_state[i].s3; + rows.table_s4[i] = precompute_table_state[i].s4; + rows.table_s5[i] = precompute_table_state[i].s5; + rows.table_s6[i] = precompute_table_state[i].s6; + rows.table_s7[i] = precompute_table_state[i].s7; + rows.table_s8[i] = precompute_table_state[i].s8; + // todo explain why skew is 7 not 1 + rows.table_skew[i] = precompute_table_state[i].skew ? 7 : 0; + + rows.table_dx[i] = precompute_table_state[i].precompute_double.x; + rows.table_dy[i] = precompute_table_state[i].precompute_double.y; + rows.table_tx[i] = precompute_table_state[i].precompute_accumulator.x; + rows.table_ty[i] = precompute_table_state[i].precompute_accumulator.y; + } + + for (size_t i = 0; i < msm_state.size(); ++i) { + rows.q_msm_transition[i] = static_cast(msm_state[i].q_msm_transition); + rows.msm_q_add[i] = static_cast(msm_state[i].q_add); + rows.msm_q_double[i] = static_cast(msm_state[i].q_double); + rows.msm_q_skew[i] = static_cast(msm_state[i].q_skew); + rows.msm_accumulator_x[i] = msm_state[i].accumulator_x; + rows.msm_accumulator_y[i] = msm_state[i].accumulator_y; + rows.msm_pc[i] = msm_state[i].pc; + rows.msm_size_of_msm[i] = msm_state[i].msm_size; + rows.msm_count[i] = msm_state[i].msm_count; + rows.msm_round[i] = msm_state[i].msm_round; + rows.msm_q_add1[i] = static_cast(msm_state[i].add_state[0].add); + rows.msm_q_add2[i] = static_cast(msm_state[i].add_state[1].add); + rows.msm_q_add3[i] = static_cast(msm_state[i].add_state[2].add); + rows.msm_q_add4[i] = static_cast(msm_state[i].add_state[3].add); + rows.msm_x1[i] = msm_state[i].add_state[0].point.x; + rows.msm_y1[i] = msm_state[i].add_state[0].point.y; + rows.msm_x2[i] = msm_state[i].add_state[1].point.x; + rows.msm_y2[i] = msm_state[i].add_state[1].point.y; + rows.msm_x3[i] = msm_state[i].add_state[2].point.x; + rows.msm_y3[i] = msm_state[i].add_state[2].point.y; + rows.msm_x4[i] = msm_state[i].add_state[3].point.x; + rows.msm_y4[i] = msm_state[i].add_state[3].point.y; + rows.msm_collision_x1[i] = msm_state[i].add_state[0].collision_inverse; + rows.msm_collision_x2[i] = msm_state[i].add_state[1].collision_inverse; + rows.msm_collision_x3[i] = msm_state[i].add_state[2].collision_inverse; + rows.msm_collision_x4[i] = msm_state[i].add_state[3].collision_inverse; + rows.msm_lambda1[i] = msm_state[i].add_state[0].lambda; + rows.msm_lambda2[i] = msm_state[i].add_state[1].lambda; + rows.msm_lambda3[i] = msm_state[i].add_state[2].lambda; + rows.msm_lambda4[i] = msm_state[i].add_state[3].lambda; + rows.msm_slice1[i] = msm_state[i].add_state[0].slice; + rows.msm_slice2[i] = msm_state[i].add_state[1].slice; + rows.msm_slice3[i] = msm_state[i].add_state[2].slice; + rows.msm_slice4[i] = msm_state[i].add_state[3].slice; + } + + rows.q_transcript_mul_shift = typename Flavor::Polynomial(rows.q_transcript_mul.shifted()); + rows.q_transcript_accumulate_shift = typename Flavor::Polynomial(rows.q_transcript_accumulate.shifted()); + rows.transcript_msm_count_shift = typename Flavor::Polynomial(rows.transcript_msm_count.shifted()); + rows.transcript_accumulator_x_shift = typename Flavor::Polynomial(rows.transcript_accumulator_x.shifted()); + rows.transcript_accumulator_y_shift = typename Flavor::Polynomial(rows.transcript_accumulator_y.shifted()); + rows.table_scalar_sum_shift = typename Flavor::Polynomial(rows.table_scalar_sum.shifted()); + rows.table_dx_shift = typename Flavor::Polynomial(rows.table_dx.shifted()); + rows.table_dy_shift = typename Flavor::Polynomial(rows.table_dy.shifted()); + rows.table_tx_shift = typename Flavor::Polynomial(rows.table_tx.shifted()); + rows.table_ty_shift = typename Flavor::Polynomial(rows.table_ty.shifted()); + rows.q_msm_transition_shift = typename Flavor::Polynomial(rows.q_msm_transition.shifted()); + rows.msm_q_add_shift = typename Flavor::Polynomial(rows.msm_q_add.shifted()); + rows.msm_q_double_shift = typename Flavor::Polynomial(rows.msm_q_double.shifted()); + rows.msm_q_skew_shift = typename Flavor::Polynomial(rows.msm_q_skew.shifted()); + rows.msm_accumulator_x_shift = typename Flavor::Polynomial(rows.msm_accumulator_x.shifted()); + rows.msm_accumulator_y_shift = typename Flavor::Polynomial(rows.msm_accumulator_y.shifted()); + rows.msm_count_shift = typename Flavor::Polynomial(rows.msm_count.shifted()); + rows.msm_round_shift = typename Flavor::Polynomial(rows.msm_round.shifted()); + rows.msm_q_add1_shift = typename Flavor::Polynomial(rows.msm_q_add1.shifted()); + rows.msm_pc_shift = typename Flavor::Polynomial(rows.msm_pc.shifted()); + rows.table_pc_shift = typename Flavor::Polynomial(rows.table_pc.shifted()); + rows.transcript_pc_shift = typename Flavor::Polynomial(rows.transcript_pc.shifted()); + rows.table_round_shift = typename Flavor::Polynomial(rows.table_round.shifted()); + rows.transcript_accumulator_empty_shift = + typename Flavor::Polynomial(rows.transcript_accumulator_empty.shifted()); + rows.q_wnaf_shift = typename Flavor::Polynomial(rows.q_wnaf.shifted()); + return rows; + } + + bool check_circuit() + { + const FF gamma = FF::random_element(); + const FF eta = FF::random_element(); + const FF eta_sqr = eta.sqr(); + const FF eta_cube = eta_sqr * eta; + auto permutation_offset = + gamma * (gamma + eta_sqr) * (gamma + eta_sqr + eta_sqr) * (gamma + eta_sqr + eta_sqr + eta_sqr); + permutation_offset = permutation_offset.invert(); + proof_system::honk::sumcheck::RelationParameters params{ + .eta = eta, + .beta = 0, + .gamma = gamma, + .public_input_delta = 0, + .lookup_grand_product_delta = 0, + .eta_sqr = eta_sqr, + .eta_cube = eta_cube, + .permutation_offset = permutation_offset, + }; + + auto rows = compute_full_polynomials(); + const size_t num_rows = rows[0].size(); + proof_system::honk::lookup_library::compute_logderivative_inverse>( + rows, params, num_rows); + + honk::permutation_library::compute_permutation_grand_product>( + num_rows, rows, params); + + rows.z_perm_shift = typename Flavor::Polynomial(rows.z_perm.shifted()); + + const auto evaluate_relation = [&](const std::string& relation_name) { + auto relation = Relation(); + typename Relation::RelationValues result; + for (auto& r : result) { + r = 0; + } + constexpr size_t NUM_SUBRELATIONS = result.size(); + + for (size_t i = 0; i < num_rows; ++i) { + typename Flavor::RowPolynomials row; + for (size_t j = 0; j < NUM_POLYNOMIALS; ++j) { + row[j] = rows[j][i]; + } + relation.add_full_relation_value_contribution(result, row, params, 1); + + bool x = true; + for (size_t j = 0; j < NUM_SUBRELATIONS; ++j) { + if (result[j] != 0) { + info("Relation ", relation_name, ", subrelation index ", j, " failed at row ", i); + x = false; + } + } + if (!x) { + return false; + } + } + return true; + }; + + bool result = true; + result = result && evaluate_relation.template operator()>( + "ECCVMTranscriptRelation"); + result = result && evaluate_relation.template operator()>( + "ECCVMPointTableRelation"); + result = + result && evaluate_relation.template operator()>("ECCVMWnafRelation"); + result = + result && evaluate_relation.template operator()>("ECCVMMSMRelation"); + result = + result && evaluate_relation.template operator()>("ECCVMSetRelation"); + + auto lookup_relation = honk::sumcheck::ECCVMLookupRelation(); + typename honk::sumcheck::ECCVMLookupRelation::RelationValues lookup_result; + for (auto& r : lookup_result) { + r = 0; + } + for (size_t i = 0; i < num_rows; ++i) { + typename Flavor::RowPolynomials row; + for (size_t j = 0; j < NUM_POLYNOMIALS; ++j) { + row[j] = rows[j][i]; + } + { + lookup_relation.add_full_relation_value_contribution(lookup_result, row, params, 1); + } + } + for (auto r : lookup_result) { + if (r != 0) { + info("Relation ECCVMLookupRelation failed."); + return false; + } + } + return result; + } +}; +} // namespace proof_system \ No newline at end of file diff --git a/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.test.cpp b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.test.cpp new file mode 100644 index 0000000000..28a8920126 --- /dev/null +++ b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.test.cpp @@ -0,0 +1,227 @@ +#include "barretenberg/crypto/generators/generator_data.hpp" +#include "barretenberg/crypto/pedersen_commitment/pedersen.hpp" +#include "eccvm_circuit_builder.hpp" +#include + +using namespace barretenberg; + +namespace { +auto& engine = numeric::random::get_debug_engine(); +} + +namespace eccvm_circuit_builder_tests { + +TEST(ECCVMCircuitConstructor, BaseCase) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::g1::element b = grumpkin::get_generator(1); + grumpkin::g1::element c = grumpkin::get_generator(2); + grumpkin::fr x = grumpkin::fr::random_element(&engine); + grumpkin::fr y = grumpkin::fr::random_element(&engine); + + grumpkin::g1::element expected_1 = (a * x) + a + a + (b * y) + (b * x) + (b * x); + grumpkin::g1::element expected_2 = (a * x) + c + (b * x); + + circuit.add_accumulate(a); + circuit.mul_accumulate(a, x); + circuit.mul_accumulate(b, x); + circuit.mul_accumulate(b, y); + circuit.add_accumulate(a); + circuit.mul_accumulate(b, x); + circuit.eq(expected_1); + circuit.add_accumulate(c); + circuit.mul_accumulate(a, x); + circuit.mul_accumulate(b, x); + circuit.eq(expected_2); + circuit.mul_accumulate(a, x); + circuit.mul_accumulate(b, x); + circuit.mul_accumulate(c, x); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, Add) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + + circuit.add_accumulate(a); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, Mul) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::fr x = grumpkin::fr::random_element(&engine); + + circuit.mul_accumulate(a, x); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, ShortMul) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + uint256_t small_x = 0; + // make sure scalar is less than 127 bits to fit in z1 + small_x.data[0] = engine.get_random_uint64(); + small_x.data[1] = engine.get_random_uint64() & 0xFFFFFFFFFFFFULL; + grumpkin::fr x = small_x; + + circuit.mul_accumulate(a, x); + circuit.eq(a * small_x); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, EqFails) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::fr x = grumpkin::fr::random_element(&engine); + + circuit.mul_accumulate(a, x); + circuit.eq(a); + bool result = circuit.check_circuit(); + EXPECT_EQ(result, false); +} + +TEST(ECCVMCircuitConstructor, EmptyRow) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::fr x = grumpkin::fr::random_element(&engine); + + grumpkin::g1::element expected_1 = (a * x); + + circuit.empty_row(); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, EmptyRowBetweenOps) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::fr x = grumpkin::fr::random_element(&engine); + + grumpkin::g1::element expected_1 = (a * x); + + circuit.mul_accumulate(a, x); + circuit.empty_row(); + circuit.eq(expected_1); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, EndWithEq) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::fr x = grumpkin::fr::random_element(&engine); + + grumpkin::g1::element expected_1 = (a * x); + + circuit.mul_accumulate(a, x); + circuit.eq(expected_1); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, EndWithAdd) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::fr x = grumpkin::fr::random_element(&engine); + + grumpkin::g1::element expected_1 = (a * x); + + circuit.mul_accumulate(a, x); + circuit.eq(expected_1); + circuit.add_accumulate(a); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, EndWithMul) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::fr x = grumpkin::fr::random_element(&engine); + + circuit.add_accumulate(a); + circuit.eq(a); + circuit.mul_accumulate(a, x); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, EndWithNoop) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::fr x = grumpkin::fr::random_element(&engine); + + circuit.add_accumulate(a); + circuit.eq(a); + circuit.mul_accumulate(a, x); + circuit.empty_row(); + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, MSM) +{ + const auto try_msms = [&](const size_t num_msms, auto& circuit) { + std::vector points; + std::vector scalars; + grumpkin::g1::element expected = grumpkin::g1::point_at_infinity; + for (size_t i = 0; i < num_msms; ++i) { + points.emplace_back(grumpkin::get_generator(i)); + scalars.emplace_back(grumpkin::fr::random_element(&engine)); + expected += (points[i] * scalars[i]); + circuit.mul_accumulate(points[i], scalars[i]); + } + circuit.eq(expected); + }; + + // single msms + for (size_t j = 1; j < 9; ++j) { + proof_system::ECCVMCircuitConstructor circuit; + try_msms(j, circuit); + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); + } + // chain msms + proof_system::ECCVMCircuitConstructor circuit; + for (size_t j = 1; j < 9; ++j) { + try_msms(j, circuit); + } + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} +} // namespace eccvm_circuit_builder_tests \ No newline at end of file diff --git a/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/msm_builder.hpp b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/msm_builder.hpp new file mode 100644 index 0000000000..a4da85d015 --- /dev/null +++ b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/msm_builder.hpp @@ -0,0 +1,263 @@ +#pragma once + +#include + +#include "./eccvm_builder_types.hpp" + +namespace proof_system { + +template class ECCVMMSMMBuilder { + public: + using CycleGroup = typename Flavor::CycleGroup; + using FF = typename Flavor::FF; + using Element = typename CycleGroup::element; + using AffineElement = typename CycleGroup::affine_element; + + static constexpr size_t ADDITIONS_PER_ROW = proof_system_eccvm::ADDITIONS_PER_ROW; + static constexpr size_t NUM_SCALAR_BITS = proof_system_eccvm::NUM_SCALAR_BITS; + static constexpr size_t WNAF_SLICE_BITS = proof_system_eccvm::WNAF_SLICE_BITS; + + struct MSMState { + uint32_t pc = 0; + uint32_t msm_size = 0; + uint32_t msm_count = 0; + uint32_t msm_round = 0; + bool q_msm_transition = false; + bool q_add = false; + bool q_double = false; + bool q_skew = false; + + struct AddState { + bool add = false; + int slice = 0; + AffineElement point{ 0, 0 }; + FF lambda = 0; + FF collision_inverse = 0; + }; + std::array add_state{ AddState{ false, 0, { 0, 0 }, 0, 0 }, + AddState{ false, 0, { 0, 0 }, 0, 0 }, + AddState{ false, 0, { 0, 0 }, 0, 0 }, + AddState{ false, 0, { 0, 0 }, 0, 0 } }; + FF accumulator_x = 0; + FF accumulator_y = 0; + }; + + static std::vector compute_msm_state(const std::vector>& msms, + std::array, 2>& point_table_read_counts, + const uint32_t total_number_of_muls) + { + // when we define our point lookup table, we have 2 write columns and 4 read columns + // when we perform a read on a given row, we need to increment the read count on the respective write column by + // 1 we can define the following struture: 1st write column = positive 2nd write column = negative the row + // number is a function of pc and slice value row = pc_delta * rows_per_point_table + some function of the slice + // value pc_delta = total_number_of_muls - pc std::vector point_table_read_counts; + const size_t table_rows = static_cast(total_number_of_muls) * 8; + point_table_read_counts[0].reserve(table_rows); + point_table_read_counts[1].reserve(table_rows); + for (size_t i = 0; i < table_rows; ++i) { + point_table_read_counts[0].emplace_back(0); + point_table_read_counts[1].emplace_back(0); + } + const auto update_read_counts = [&](const size_t pc, const int slice) { + // When we compute our wnaf/point tables, we start with the point with the largest pc value. + // i.e. if we are reading a slice for point with a point counter value `pc`, + // its position in the wnaf/point table (relative to other points) will be `total_number_of_muls - pc` + const size_t pc_delta = total_number_of_muls - pc; + const size_t pc_offset = pc_delta * 8; + bool slice_negative = slice < 0; + const int slice_row = (slice + 15) / 2; + + const size_t column_index = slice_negative ? 1 : 0; + + if (slice_negative) { + point_table_read_counts[column_index][pc_offset + static_cast(slice_row)]++; + } else { + // 8 maps to 7 + // 15 maps to 0 + + // 15 - x + point_table_read_counts[column_index][pc_offset + 15 - static_cast(slice_row)]++; + } + // slice : row + // -15 : 0 + // -13 : 1 + // -11 : 2 + // -9 : 3 + // -7 : 4 + // -5 : 5 + // -3 : 6 + // -1 : 7 + // 1 : 8 + // 3 : 9 + // 5 : 10 + // 7 : 11 + // 9 : 12 + // 11 : 13 + // 13 : 14 + // 15 : 15 + }; + std::vector msm_state; + // start with empty row (shiftable polynomials must have 0 as first coefficient) + msm_state.emplace_back(MSMState{}); + uint32_t pc = total_number_of_muls; + AffineElement accumulator = CycleGroup::affine_point_at_infinity; + + for (const auto& msm : msms) { + const size_t msm_size = msm.size(); + + const size_t rows_per_round = (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); + static constexpr size_t num_rounds = NUM_SCALAR_BITS / WNAF_SLICE_BITS; + + const auto add_points = [](auto& P1, auto& P2, auto& lambda, auto& collision_inverse, bool predicate) { + lambda = predicate ? (P2.y - P1.y) / (P2.x - P1.x) : 0; + collision_inverse = predicate ? (P2.x - P1.x).invert() : 0; + auto x3 = predicate ? lambda * lambda - (P2.x + P1.x) : P1.x; + auto y3 = predicate ? lambda * (P1.x - x3) - P1.y : P1.y; + return AffineElement(x3, y3); + }; + for (size_t j = 0; j < num_rounds; ++j) { + for (size_t k = 0; k < rows_per_round; ++k) { + MSMState row; + const size_t points_per_row = + (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW; + const size_t idx = k * ADDITIONS_PER_ROW; + row.q_msm_transition = (j == 0) && (k == 0); + + AffineElement acc(accumulator); + Element acc_expected = accumulator; + for (size_t m = 0; m < 4; ++m) { + auto& add_state = row.add_state[m]; + add_state.add = points_per_row > m; + int slice = add_state.add ? msm[idx + m].wnaf_slices[j] : 0; + add_state.slice = add_state.add ? (slice + 15) / 2 : 0; + add_state.point = add_state.add + ? msm[idx + m].precomputed_table[static_cast(add_state.slice)] + : AffineElement{ 0, 0 }; + bool add_predicate = (m == 0 ? (j != 0 || k != 0) : add_state.add); + + auto& p1 = (m == 0) ? add_state.point : acc; + auto& p2 = (m == 0) ? acc : add_state.point; + + acc_expected = add_predicate ? (acc_expected + add_state.point) : Element(p1); + if (add_state.add) { + update_read_counts(pc - idx - m, slice); + } + acc = add_points(p1, p2, add_state.lambda, add_state.collision_inverse, add_predicate); + ASSERT(acc == AffineElement(acc_expected)); + } + row.q_add = true; + row.q_double = false; + row.q_skew = false; + row.msm_round = static_cast(j); + row.msm_size = static_cast(msm_size); + row.msm_count = static_cast(idx); + row.accumulator_x = accumulator.is_point_at_infinity() ? 0 : accumulator.x; + row.accumulator_y = accumulator.is_point_at_infinity() ? 0 : accumulator.y; + row.pc = pc; + accumulator = acc; + msm_state.push_back(row); + } + if (j < num_rounds - 1) { + MSMState row; + row.q_msm_transition = false; + row.msm_round = static_cast(j + 1); + row.msm_size = static_cast(msm_size); + row.msm_count = static_cast(0); + row.q_add = false; + row.q_double = true; + row.q_skew = false; + + auto dx = accumulator.x; + auto dy = accumulator.y; + for (size_t m = 0; m < 4; ++m) { + auto& add_state = row.add_state[m]; + add_state.add = false; + add_state.slice = 0; + add_state.point = { 0, 0 }; + add_state.collision_inverse = 0; + add_state.lambda = ((dx + dx + dx) * dx) / (dy + dy); + auto x3 = add_state.lambda.sqr() - dx - dx; + dy = add_state.lambda * (dx - x3) - dy; + dx = x3; + } + + row.accumulator_x = accumulator.is_point_at_infinity() ? 0 : accumulator.x; + row.accumulator_y = accumulator.is_point_at_infinity() ? 0 : accumulator.y; + accumulator = Element(accumulator).dbl().dbl().dbl().dbl(); + row.pc = pc; + msm_state.push_back(row); + } else { + for (size_t k = 0; k < rows_per_round; ++k) { + MSMState row; + + const size_t points_per_row = + (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW; + const size_t idx = k * ADDITIONS_PER_ROW; + row.q_msm_transition = false; + + AffineElement acc(accumulator); + Element acc_expected = accumulator; + + for (size_t m = 0; m < 4; ++m) { + auto& add_state = row.add_state[m]; + add_state.add = points_per_row > m; + add_state.slice = add_state.add ? msm[idx + m].wnaf_skew ? 7 : 0 : 0; + + add_state.point = add_state.add + ? msm[idx + m].precomputed_table[static_cast(add_state.slice)] + : AffineElement{ 0, 0 }; + bool add_predicate = add_state.add ? msm[idx + m].wnaf_skew : false; + if (add_state.add) { + update_read_counts(pc - idx - m, msm[idx + m].wnaf_skew ? -1 : -15); + } + acc = add_points( + acc, add_state.point, add_state.lambda, add_state.collision_inverse, add_predicate); + acc_expected = add_predicate ? (acc_expected + add_state.point) : acc_expected; + ASSERT(acc == AffineElement(acc_expected)); + } + row.q_add = false; + row.q_double = false; + row.q_skew = true; + row.msm_round = static_cast(j + 1); + row.msm_size = static_cast(msm_size); + row.msm_count = static_cast(idx); + + row.accumulator_x = accumulator.is_point_at_infinity() ? 0 : accumulator.x; + row.accumulator_y = accumulator.is_point_at_infinity() ? 0 : accumulator.y; + + row.pc = pc; + accumulator = acc; + msm_state.emplace_back(row); + } + } + } + pc -= static_cast(msm_size); + // Validate our computed accumulator matches the real MSM result! + Element expected = CycleGroup::point_at_infinity; + for (size_t i = 0; i < msm.size(); ++i) { + expected += (Element(msm[i].base_point) * msm[i].scalar); + } + // Validate the accumulator is correct! + ASSERT(accumulator == AffineElement(expected)); + } + + MSMState final_row; + final_row.pc = pc; + final_row.q_msm_transition = true; + final_row.accumulator_x = accumulator.is_point_at_infinity() ? 0 : accumulator.x; + final_row.accumulator_y = accumulator.is_point_at_infinity() ? 0 : accumulator.y; + final_row.msm_size = 0; + final_row.msm_count = 0; + final_row.q_add = false; + final_row.q_double = false; + final_row.q_skew = false; + final_row.add_state = { typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, + typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, + typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, + typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 } }; + + msm_state.emplace_back(final_row); + return msm_state; + } +}; +} // namespace proof_system \ No newline at end of file diff --git a/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/precomputed_tables_builder.hpp b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/precomputed_tables_builder.hpp new file mode 100644 index 0000000000..27c9cf48df --- /dev/null +++ b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/precomputed_tables_builder.hpp @@ -0,0 +1,112 @@ +#pragma once + +#include "./eccvm_builder_types.hpp" + +namespace proof_system { + +template class ECCVMPrecomputedTablesBuilder { + public: + using CycleGroup = typename Flavor::CycleGroup; + using FF = typename Flavor::FF; + using Element = typename CycleGroup::element; + using AffineElement = typename CycleGroup::affine_element; + + static constexpr size_t NUM_WNAF_SLICES = proof_system_eccvm::NUM_WNAF_SLICES; + static constexpr size_t WNAF_SLICES_PER_ROW = proof_system_eccvm::WNAF_SLICES_PER_ROW; + static constexpr size_t WNAF_SLICE_BITS = proof_system_eccvm::WNAF_SLICE_BITS; + + struct PrecomputeState { + int s1 = 0; + int s2 = 0; + int s3 = 0; + int s4 = 0; + int s5 = 0; + int s6 = 0; + int s7 = 0; + int s8 = 0; + bool skew = false; + bool point_transition = false; + uint32_t pc = 0; + uint32_t round = 0; + uint256_t scalar_sum = 0; + AffineElement precompute_accumulator{ 0, 0 }; + AffineElement precompute_double{ 0, 0 }; + }; + + static std::vector compute_precompute_state( + const std::vector>& ecc_muls) + { + std::vector precompute_state; + + // start with empty row (shiftable polynomials must have 0 as first coefficient) + precompute_state.push_back(PrecomputeState{}); + static constexpr size_t num_rows_per_scalar = NUM_WNAF_SLICES / WNAF_SLICES_PER_ROW; + + // current impl doesn't work if not 4 + static_assert(WNAF_SLICES_PER_ROW == 4); + + for (const auto& entry : ecc_muls) { + const auto& slices = entry.wnaf_slices; + uint256_t scalar_sum = 0; + + const Element point = entry.base_point; + const Element d2 = point.dbl(); + + for (size_t i = 0; i < num_rows_per_scalar; ++i) { + PrecomputeState row; + const int slice0 = slices[i * WNAF_SLICES_PER_ROW]; + const int slice1 = slices[i * WNAF_SLICES_PER_ROW + 1]; + const int slice2 = slices[i * WNAF_SLICES_PER_ROW + 2]; + const int slice3 = slices[i * WNAF_SLICES_PER_ROW + 3]; + + const int slice0base2 = (slice0 + 15) / 2; + const int slice1base2 = (slice1 + 15) / 2; + const int slice2base2 = (slice2 + 15) / 2; + const int slice3base2 = (slice3 + 15) / 2; + + // convert into 2-bit chunks + row.s1 = slice0base2 >> 2; + row.s2 = slice0base2 & 3; + row.s3 = slice1base2 >> 2; + row.s4 = slice1base2 & 3; + row.s5 = slice2base2 >> 2; + row.s6 = slice2base2 & 3; + row.s7 = slice3base2 >> 2; + row.s8 = slice3base2 & 3; + bool last_row = (i == num_rows_per_scalar - 1); + + row.skew = last_row ? entry.wnaf_skew : false; + + row.scalar_sum = scalar_sum; + + // TODO(@zac-williamson). If 1st row do we apply constraint that requires slice0 to be positive? + // Need this if we want to rule out negative values (i.e. input has not yet been range + // constrained) + const int row_chunk = slice3 + slice2 * (1 << 4) + slice1 * (1 << 8) + slice0 * (1 << 12); + + bool chunk_negative = row_chunk < 0; + + scalar_sum = scalar_sum << (WNAF_SLICE_BITS * WNAF_SLICES_PER_ROW); + if (chunk_negative) { + scalar_sum -= static_cast(-row_chunk); + } else { + scalar_sum += static_cast(row_chunk); + } + row.round = static_cast(i); + row.point_transition = last_row; + row.pc = entry.pc; + + if (last_row) { + ASSERT(scalar_sum - entry.wnaf_skew == entry.scalar); + } + + row.precompute_double = d2; + // fill accumulator in reverse order i.e. first row = 15[P], then 13[P], ..., 1[P] + row.precompute_accumulator = entry.precomputed_table[proof_system_eccvm::POINT_TABLE_SIZE - 1 - i]; + precompute_state.emplace_back(row); + } + } + return precompute_state; + } +}; +} // namespace proof_system \ No newline at end of file diff --git a/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/transcript_builder.hpp b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/transcript_builder.hpp new file mode 100644 index 0000000000..74a2b6ab8d --- /dev/null +++ b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/transcript_builder.hpp @@ -0,0 +1,175 @@ +#pragma once + +#include "./eccvm_builder_types.hpp" + +namespace proof_system { + +template class ECCVMTranscriptBuilder { + public: + using CycleGroup = typename Flavor::CycleGroup; + using FF = typename Flavor::FF; + using Element = typename CycleGroup::element; + using AffineElement = typename CycleGroup::affine_element; + + struct TranscriptState { + bool accumulator_empty = false; + bool q_add = false; + bool q_mul = false; + bool q_eq = false; + bool q_reset_accumulator = false; + bool q_msm_transition = false; + uint32_t pc = 0; + uint32_t msm_count = 0; + FF base_x = 0; + FF base_y = 0; + uint256_t z1 = 0; + uint256_t z2 = 0; + bool z1_zero = false; + bool z2_zero = false; + uint32_t opcode = 0; + FF accumulator_x = 0; + FF accumulator_y = 0; + FF msm_output_x = 0; + FF msm_output_y = 0; + }; + struct VMState { + uint32_t pc = 0; + uint32_t count = 0; + AffineElement accumulator = CycleGroup::affine_point_at_infinity; + AffineElement msm_accumulator = CycleGroup::affine_point_at_infinity; + bool is_accumulator_empty = true; + }; + struct Opcode { + bool add; + bool mul; + bool eq; + bool reset; + [[nodiscard]] uint32_t value() const + { + auto res = static_cast(add); + res += res; + res += static_cast(mul); + res += res; + res += static_cast(eq); + res += res; + res += static_cast(reset); + return res; + } + }; + static std::vector compute_transcript_state( + const std::vector>& vm_operations, + const uint32_t total_number_of_muls) + { + std::vector transcript_state; + VMState state{ + .pc = total_number_of_muls, + .count = 0, + .accumulator = CycleGroup::affine_point_at_infinity, + .msm_accumulator = CycleGroup::affine_point_at_infinity, + .is_accumulator_empty = true, + }; + VMState updated_state; + + // add an empty row. 1st row all zeroes because of our shiftable polynomials + transcript_state.emplace_back(TranscriptState{}); + for (size_t i = 0; i < vm_operations.size(); ++i) { + TranscriptState row; + const proof_system_eccvm::VMOperation& entry = vm_operations[i]; + + const bool is_mul = entry.mul; + const bool z1_zero = (entry.mul) ? entry.z1 == 0 : true; + const bool z2_zero = (entry.mul) ? entry.z2 == 0 : true; + const uint32_t num_muls = is_mul ? (static_cast(!z1_zero) + static_cast(!z2_zero)) : 0; + + updated_state = state; + + if (entry.reset) { + updated_state.is_accumulator_empty = true; + updated_state.msm_accumulator = CycleGroup::affine_point_at_infinity; + } + updated_state.pc = state.pc - num_muls; + + bool last_row = i == (vm_operations.size() - 1); + // msm transition = current row is doing a lookup to validate output = msm output + // i.e. next row is not part of MSM and current row is part of MSM + // or next row is irrelevent and current row is a straight MUL + bool next_not_msm = last_row ? true : !vm_operations[i + 1].mul; + + bool msm_transition = entry.mul && next_not_msm; + + // we reset the count in updated state if we are not accumulating and not doing an msm + bool current_msm = entry.mul; + bool current_ongoing_msm = entry.mul && !next_not_msm; + updated_state.count = current_ongoing_msm ? state.count + num_muls : 0; + + if (current_msm) { + const auto P = grumpkin::g1::element(entry.base_point); + const auto R = grumpkin::g1::element(state.msm_accumulator); + updated_state.msm_accumulator = R + P * entry.mul_scalar_full; + } + + if (entry.mul && next_not_msm) { + if (state.is_accumulator_empty) { + updated_state.accumulator = updated_state.msm_accumulator; + } else { + const auto R = grumpkin::g1::element(state.accumulator); + updated_state.accumulator = R + updated_state.msm_accumulator; + } + updated_state.is_accumulator_empty = false; + } + + bool add_accumulate = entry.add; + if (add_accumulate) { + if (state.is_accumulator_empty) { + + updated_state.accumulator = entry.base_point; + } else { + updated_state.accumulator = grumpkin::g1::element(state.accumulator) + entry.base_point; + } + updated_state.is_accumulator_empty = false; + } + row.accumulator_empty = state.is_accumulator_empty; + row.q_add = entry.add; + row.q_mul = entry.mul; + row.q_eq = entry.eq; + row.q_reset_accumulator = entry.reset; + row.q_msm_transition = msm_transition; + row.pc = state.pc; + row.msm_count = state.count; + row.base_x = (entry.add || entry.mul || entry.eq) ? entry.base_point.x : 0; + row.base_y = (entry.add || entry.mul || entry.eq) ? entry.base_point.y : 0; + row.z1 = (entry.mul) ? entry.z1 : 0; + row.z2 = (entry.mul) ? entry.z2 : 0; + row.z1_zero = z1_zero; + row.z2_zero = z2_zero; + row.opcode = Opcode{ .add = entry.add, .mul = entry.mul, .eq = entry.eq, .reset = entry.reset }.value(); + row.accumulator_x = (state.accumulator.is_point_at_infinity()) ? 0 : state.accumulator.x; + row.accumulator_y = (state.accumulator.is_point_at_infinity()) ? 0 : state.accumulator.y; + row.msm_output_x = + msm_transition + ? (updated_state.msm_accumulator.is_point_at_infinity() ? 0 : updated_state.msm_accumulator.x) + : 0; + row.msm_output_y = + msm_transition + ? (updated_state.msm_accumulator.is_point_at_infinity() ? 0 : updated_state.msm_accumulator.y) + : 0; + + state = updated_state; + + if (entry.mul && next_not_msm) { + state.msm_accumulator = CycleGroup::affine_point_at_infinity; + } + transcript_state.emplace_back(row); + } + + TranscriptState final_row; + final_row.pc = updated_state.pc; + final_row.accumulator_x = (updated_state.accumulator.is_point_at_infinity()) ? 0 : updated_state.accumulator.x; + final_row.accumulator_y = (updated_state.accumulator.is_point_at_infinity()) ? 0 : updated_state.accumulator.y; + final_row.accumulator_empty = updated_state.is_accumulator_empty; + + transcript_state.push_back(final_row); + return transcript_state; + } +}; +} // namespace proof_system \ No newline at end of file diff --git a/cpp/src/barretenberg/proof_system/flavor/flavor.hpp b/cpp/src/barretenberg/proof_system/flavor/flavor.hpp index 38cb073e7d..9477ea84ba 100644 --- a/cpp/src/barretenberg/proof_system/flavor/flavor.hpp +++ b/cpp/src/barretenberg/proof_system/flavor/flavor.hpp @@ -275,6 +275,8 @@ class Standard; class StandardGrumpkin; class Ultra; class UltraGrumpkin; +class ECCVM; +class ECCVMGrumpkin; } // namespace proof_system::honk::flavor // Forward declare plonk flavors @@ -305,5 +307,7 @@ template concept StandardFlavor = IsAnyOf concept UltraFlavor = IsAnyOf; +template concept ECCVMFlavor = IsAnyOf; + // clang-format on } // namespace proof_system diff --git a/cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_lookup_relation.cpp b/cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_lookup_relation.cpp new file mode 100644 index 0000000000..f703dc8067 --- /dev/null +++ b/cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_lookup_relation.cpp @@ -0,0 +1,89 @@ +#include "ecc_msm_relation.hpp" +#include "barretenberg/honk/sumcheck/relations/relation_parameters.hpp" +#include "barretenberg/honk/sumcheck/relations/relation_definitions_fwd.hpp" +#include "barretenberg/honk/flavor/ecc_vm.hpp" + +namespace proof_system::honk::sumcheck { + +/** + * @brief Expression for the StandardArithmetic gate. + * @details The relation is defined as C(extended_edges(X)...) = + * (q_m * w_r * w_l) + (q_l * w_l) + (q_r * w_r) + (q_o * w_o) + q_c + * + * @param evals transformed to `evals + C(extended_edges(X)...)*scaling_factor` + * @param extended_edges an std::array containing the fully extended Accumulator edges. + * @param parameters contains beta, gamma, and public_input_delta, .... + * @param scaling_factor optional term to scale the evaluation before adding to evals. + */ +template +template +void ECCVMLookupRelationBase::add_edge_contribution_impl(typename AccumulatorTypes::Accumulators& accumulator, + const auto& extended_edges, + const RelationParameters& relation_params, + const FF& /*unused*/) const +{ + using View = typename std::tuple_element<0, typename AccumulatorTypes::AccumulatorViews>::type; + using Accumulator = typename std::tuple_element<0, typename AccumulatorTypes::Accumulators>::type; + + auto lookup_inverses = View(extended_edges.lookup_inverses); + + constexpr size_t NUM_TOTAL_TERMS = READ_TERMS + WRITE_TERMS; + std::array lookup_terms; + std::array denominator_accumulator; + + // The lookup relation = \sum_j (1 / read_term[j]) - \sum_k (read_counts[k] / write_term[k]) + // To get the inverses (1 / read_term[i]), (1 / write_term[i]), we have a commitment to the product of all inverses + // i.e. lookup_inverse = \prod_j (1 / read_term[j]) * \prod_k (1 / write_term[k]) + // The purpose of this next section is to derive individual inverse terms using `lookup_inverses` + // i.e. (1 / read_term[i]) = lookup_inverse * \prod_{j /ne i} (read_term[j]) * \prod_k (write_term[k]) + // (1 / write_term[i]) = lookup_inverse * \prod_j (read_term[j]) * \prod_{k ne i} (write_term[k]) + barretenberg::constexpr_for<0, READ_TERMS, 1>([&]() { + lookup_terms[i] = compute_read_term(extended_edges, relation_params, 0); + }); + barretenberg::constexpr_for<0, WRITE_TERMS, 1>([&]() { + lookup_terms[i + READ_TERMS] = compute_write_term(extended_edges, relation_params, 0); + }); + + barretenberg::constexpr_for<0, NUM_TOTAL_TERMS, 1>( + [&]() { denominator_accumulator[i] = lookup_terms[i]; }); + + barretenberg::constexpr_for<0, NUM_TOTAL_TERMS - 1, 1>( + [&]() { denominator_accumulator[i + 1] *= denominator_accumulator[i]; }); + + Accumulator inverse_accumulator = Accumulator(lookup_inverses); // denominator_accumulator[NUM_TOTAL_TERMS - 1]; + + const auto row_has_write = View(extended_edges.q_wnaf); + const auto row_has_read = View(extended_edges.msm_q_add + extended_edges.msm_q_skew); + const auto inverse_exists = row_has_write + row_has_read - (row_has_write * row_has_read); + + std::get<1>(accumulator) += denominator_accumulator[NUM_TOTAL_TERMS - 1] * lookup_inverses - inverse_exists; + + // After this algo, total degree of denominator_accumulator = NUM_TOTAL_TERMA + for (size_t i = 0; i < NUM_TOTAL_TERMS - 1; ++i) { + denominator_accumulator[NUM_TOTAL_TERMS - 1 - i] = + denominator_accumulator[NUM_TOTAL_TERMS - 2 - i] * inverse_accumulator; + inverse_accumulator = inverse_accumulator * lookup_terms[NUM_TOTAL_TERMS - 1 - i]; + } + denominator_accumulator[0] = inverse_accumulator; + + // each predicate is degree-1 + // degree of relation at this point = NUM_TOTAL_TERMS + 1 + barretenberg::constexpr_for<0, READ_TERMS, 1>([&]() { + std::get<0>(accumulator) += + compute_read_term_predicate(extended_edges, relation_params, 0) * + denominator_accumulator[i]; + }); + + // each predicate is degree-1, `lookup_read_counts` is degree-1 + // degree of relation = NUM_TOTAL_TERMS + 2 = 6 + 2 + barretenberg::constexpr_for<0, WRITE_TERMS, 1>([&]() { + const auto p = compute_write_term_predicate(extended_edges, relation_params, 0); + const auto lookup_read_count = View(extended_edges.template lookup_read_counts()); + std::get<0>(accumulator) -= p * (denominator_accumulator[i + READ_TERMS] * lookup_read_count); + }); +} +template class ECCVMLookupRelationBase; +DEFINE_SUMCHECK_RELATION_CLASS(ECCVMLookupRelationBase, flavor::ECCVM); +DEFINE_SUMCHECK_RELATION_CLASS(ECCVMLookupRelationBase, flavor::ECCVMGrumpkin); + +} // namespace proof_system::honk::sumcheck diff --git a/cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_lookup_relation.hpp b/cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_lookup_relation.hpp new file mode 100644 index 0000000000..aaf1b905a8 --- /dev/null +++ b/cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_lookup_relation.hpp @@ -0,0 +1,259 @@ +#pragma once +#include +#include + +#include "barretenberg/polynomials/polynomial.hpp" +#include "barretenberg/common/constexpr_utils.hpp" +#include "barretenberg/honk/sumcheck/relations/relation_parameters.hpp" +#include "barretenberg/honk/sumcheck/relations/relation_types.hpp" +#include "barretenberg/honk/sumcheck/polynomials/univariate.hpp" + +namespace proof_system::honk::sumcheck { + +template class ECCVMLookupRelationBase { + public: + static constexpr size_t READ_TERMS = 4; + static constexpr size_t WRITE_TERMS = 2; + // 1 + polynomial degree of this relation + static constexpr size_t RELATION_LENGTH = READ_TERMS + WRITE_TERMS + 3; // 9 + + static constexpr size_t LEN_1 = RELATION_LENGTH; // grand product construction sub-relation + static constexpr size_t LEN_2 = RELATION_LENGTH; // left-shiftable polynomial sub-relation + template