Skip to content

Commit

Permalink
feat: Optimize memory consumption of pedersen generators (#413)
Browse files Browse the repository at this point in the history
Co-authored-by: Charlie Lye <karl.lye@gmail.com>
  • Loading branch information
suyash67 and charlielye authored Jun 28, 2023
1 parent bc0844b commit d60b16a
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 87 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
- clang >= 10 or gcc >= 10
- clang-format
- libomp (if multithreading is required. Multithreading can be disabled using the compiler flag `-DMULTITHREADING 0`)
- wasm-opt (part of the [Binaryen](https://github.com/WebAssembly/binaryen) toolkit)

To install on Ubuntu, run:

```
sudo apt-get install cmake clang clang-format ninja-build binaryen
sudo apt-get install cmake clang clang-format ninja-build
```

### Installing openMP (Linux)
Expand Down
2 changes: 1 addition & 1 deletion cpp/.aztec-packages-commit
Original file line number Diff line number Diff line change
@@ -1 +1 @@
master
3e16992198189112739e3710860e7d7717366108
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
FROM ubuntu:kinetic
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y bash build-essential git libssl-dev cmake ninja-build curl binaryen xz-utils curl
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y bash build-essential git libssl-dev cmake ninja-build curl xz-utils curl

RUN curl https://wasmtime.dev/install.sh -sSf | bash /dev/stdin --version v3.0.1
WORKDIR /usr/src/barretenberg/cpp
Expand Down
2 changes: 1 addition & 1 deletion cpp/dockerfiles/Dockerfile.wasm-linux-clang
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
FROM ubuntu:kinetic AS builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential wget git libssl-dev cmake ninja-build curl binaryen
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential wget git libssl-dev cmake ninja-build curl
RUN curl https://wasmtime.dev/install.sh -sSf | bash /dev/stdin --version v3.0.1
WORKDIR /usr/src/barretenberg/cpp
COPY ./scripts/install-wasi-sdk.sh ./scripts/install-wasi-sdk.sh
Expand Down
186 changes: 109 additions & 77 deletions cpp/src/barretenberg/crypto/generators/generator_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,59 +4,64 @@ namespace crypto {
namespace generators {
namespace {

// Parameters for generator table construction
struct GeneratorParameters {
size_t num_default_generators; // Number of unique base points with default main index with precomputed ladders
size_t num_hash_indices; // Number of unique hash indices
size_t num_generators_per_hash_index; // Number of generators per hash index
size_t hash_indices_generator_offset; // Offset for hash index generators
// The number of unique base points with default main index with precomputed ladders
constexpr size_t num_default_generators = 200;

/**
* @brief Contains number of hash indices all of which support a fixed number of generators per index.
*/
struct HashIndexParams {
size_t num_indices;
size_t num_generators_per_index;

/**
* @brief Computes the total number of generators for a given HashIndexParams.
*
* @return Number of generators.
*/
constexpr size_t total_generators() const { return (num_indices * num_generators_per_index); }
};

// Define BARRETENBERG_CRYPTO_GENERATOR_PARAMETERS_HACK to use custom values for generator parameters
// This hack is to avoid breakage due to generators in aztec circuits while maintaining compatibility
// with the barretenberg master.
#ifdef BARRETENBERG_CRYPTO_GENERATOR_PARAMETERS_HACK
constexpr GeneratorParameters GEN_PARAMS = { BARRETENBERG_CRYPTO_GENERATOR_PARAMETERS_HACK };
#else
#ifdef __wasm__
constexpr GeneratorParameters GEN_PARAMS = { 32, 16, 8, 2048 };
// TODO need to resolve memory out of bounds when these are too high
#else
constexpr GeneratorParameters GEN_PARAMS = { 2048, 16, 8, 2048 };
#endif
#endif

constexpr size_t num_indexed_generators = GEN_PARAMS.num_hash_indices * GEN_PARAMS.num_generators_per_hash_index;
constexpr size_t size_of_generator_data_array = GEN_PARAMS.hash_indices_generator_offset + num_indexed_generators;
constexpr HashIndexParams LOW = { 32, 8 };
constexpr HashIndexParams MID = { 8, 16 };
constexpr HashIndexParams HIGH = { 4, 44 };

constexpr size_t num_hash_indices = (LOW.num_indices + MID.num_indices + HIGH.num_indices);
constexpr size_t num_indexed_generators = LOW.total_generators() + MID.total_generators() + HIGH.total_generators();

constexpr size_t size_of_generator_data_array = num_default_generators + num_indexed_generators;
constexpr size_t num_generator_types = 3;

ladder_t g1_ladder;
bool inited = false;

void compute_fixed_base_ladder(const grumpkin::g1::affine_element& generator, ladder_t& ladder)
template <size_t ladder_length, size_t ladder_max_length>
void compute_fixed_base_ladder(const grumpkin::g1::affine_element& generator,
std::array<fixed_base_ladder, ladder_max_length>& ladder)
{
ASSERT(ladder_length <= ladder_max_length);
grumpkin::g1::element* ladder_temp =
static_cast<grumpkin::g1::element*>(aligned_alloc(64, sizeof(grumpkin::g1::element) * (quad_length * 2)));
static_cast<grumpkin::g1::element*>(aligned_alloc(64, sizeof(grumpkin::g1::element) * (ladder_length * 2)));

grumpkin::g1::element accumulator;
accumulator = grumpkin::g1::element(generator);
for (size_t i = 0; i < quad_length; ++i) {
for (size_t i = 0; i < ladder_length; ++i) {
ladder_temp[i] = accumulator;
accumulator.self_dbl();
ladder_temp[quad_length + i] = ladder_temp[i] + accumulator;
ladder_temp[ladder_length + i] = ladder_temp[i] + accumulator;
accumulator.self_dbl();
}
grumpkin::g1::element::batch_normalize(&ladder_temp[0], quad_length * 2);
for (size_t i = 0; i < quad_length; ++i) {
grumpkin::fq::__copy(ladder_temp[i].x, ladder[quad_length - 1 - i].one.x);
grumpkin::fq::__copy(ladder_temp[i].y, ladder[quad_length - 1 - i].one.y);
grumpkin::fq::__copy(ladder_temp[quad_length + i].x, ladder[quad_length - 1 - i].three.x);
grumpkin::fq::__copy(ladder_temp[quad_length + i].y, ladder[quad_length - 1 - i].three.y);
grumpkin::g1::element::batch_normalize(&ladder_temp[0], ladder_length * 2);
for (size_t i = 0; i < ladder_length; ++i) {
grumpkin::fq::__copy(ladder_temp[i].x, ladder[ladder_length - 1 - i].one.x);
grumpkin::fq::__copy(ladder_temp[i].y, ladder[ladder_length - 1 - i].one.y);
grumpkin::fq::__copy(ladder_temp[ladder_length + i].x, ladder[ladder_length - 1 - i].three.x);
grumpkin::fq::__copy(ladder_temp[ladder_length + i].y, ladder[ladder_length - 1 - i].three.y);
}

constexpr grumpkin::fq eight_inverse = grumpkin::fq{ 8, 0, 0, 0 }.to_montgomery_form().invert();
std::array<grumpkin::fq, quad_length> y_denominators;
for (size_t i = 0; i < quad_length; ++i) {
std::array<grumpkin::fq, ladder_length> y_denominators;
for (size_t i = 0; i < ladder_length; ++i) {

grumpkin::fq x_beta = ladder[i].one.x;
grumpkin::fq x_gamma = ladder[i].three.x;
Expand Down Expand Up @@ -84,8 +89,8 @@ void compute_fixed_base_ladder(const grumpkin::g1::affine_element& generator, la
ladder[i].q_y_1 = y_alpha_1;
ladder[i].q_y_2 = y_alpha_2;
}
grumpkin::fq::batch_invert(&y_denominators[0], quad_length);
for (size_t i = 0; i < quad_length; ++i) {
grumpkin::fq::batch_invert(&y_denominators[0], ladder_length);
for (size_t i = 0; i < ladder_length; ++i) {
ladder[i].q_y_1 *= y_denominators[i];
ladder[i].q_y_2 *= y_denominators[i];
}
Expand Down Expand Up @@ -125,25 +130,19 @@ auto compute_generator_data(grumpkin::g1::affine_element const& generator,
gen_data->aux_generator = aux_generator;
gen_data->skew_generator = skew_generator;

compute_fixed_base_ladder(generator, gen_data->ladder);
compute_fixed_base_ladder(aux_generator, gen_data->aux_ladder);

constexpr size_t first_generator_segment = quad_length - 2;
constexpr size_t second_generator_segment = 2;
compute_fixed_base_ladder<quad_length>(generator, gen_data->ladder);
std::array<fixed_base_ladder, aux_length> aux_ladder_temp;
compute_fixed_base_ladder<aux_length>(aux_generator, aux_ladder_temp);

for (size_t j = 0; j < first_generator_segment; ++j) {
gen_data->hash_ladder[j] = gen_data->ladder[j + (quad_length - first_generator_segment)];
}
for (size_t j = 0; j < second_generator_segment; ++j) {
gen_data->hash_ladder[j + first_generator_segment] =
gen_data->aux_ladder[j + (quad_length - second_generator_segment)];
// Fill in the aux_generator multiples in the last two indices of the ladder.
for (size_t j = 0; j < aux_length; ++j) {
gen_data->ladder[j + quad_length] = aux_ladder_temp[j];
}

return gen_data;
}

const fixed_base_ladder* get_ladder_internal(std::array<fixed_base_ladder, quad_length> const& ladder,
const size_t num_bits)
const fixed_base_ladder* get_ladder_internal(ladder_t const& ladder, const size_t num_bits, const size_t offset = 0)
{
// find n, such that 2n + 1 >= num_bits
size_t n;
Expand All @@ -155,7 +154,7 @@ const fixed_base_ladder* get_ladder_internal(std::array<fixed_base_ladder, quad_
++n;
}
}
const fixed_base_ladder* result = &ladder[quad_length - n - 1];
const fixed_base_ladder* result = &ladder[quad_length + offset - n - 1];
return result;
}

Expand Down Expand Up @@ -224,15 +223,15 @@ std::vector<std::unique_ptr<generator_data>> const& init_generator_data()

global_generator_data.resize(size_of_generator_data_array);

for (size_t i = 0; i < GEN_PARAMS.num_default_generators; i++) {
for (size_t i = 0; i < num_default_generators; i++) {
global_generator_data[i] = compute_generator_data(generators[i], aux_generators[i], skew_generators[i]);
}

for (size_t i = GEN_PARAMS.hash_indices_generator_offset; i < size_of_generator_data_array; i++) {
for (size_t i = num_default_generators; i < size_of_generator_data_array; i++) {
global_generator_data[i] = compute_generator_data(generators[i], aux_generators[i], skew_generators[i]);
}

compute_fixed_base_ladder(grumpkin::g1::one, g1_ladder);
compute_fixed_base_ladder<quad_length>(grumpkin::g1::one, g1_ladder);

inited = true;
return global_generator_data;
Expand All @@ -245,40 +244,73 @@ const fixed_base_ladder* get_g1_ladder(const size_t num_bits)
}

/**
* Generator indexing:
* @brief Returns a reference to the generator data for the specified generator index.
* The generator index is composed of an index and sub-index. The index specifies
* which hash index the generator belongs to, and the sub-index specifies the
* position of the generator within the hash index.
*
* Number of default generators (index = 0): N = 2048
* Number of hash indices: H = 32
* Number of sub indices for a given hash index: h = 64.
* Number of types of generators needed per hash index: t = 3
* The generator data is stored in a global array of generator_data objects, which
* is initialized lazily when the function is called for the first time. The global
* array includes both default generators and user-defined generators.
*
* Default generators:
* 0: P_0 P_1 P_2 ... P_{N'-1}
* If the specified index is 0, the sub-index is used to look up the corresponding
* default generator in the global array. Otherwise, the global index of the generator
* is calculated based on the index and sub-index, and used to look up the corresponding
* user-defined generator in the global array.
*
* Hash-index dependent generators: (let N' = t * N)
* 1: P_{N' + 0*h*t} P_{N' + 0*h*t + 1*t} ... P_{N' + 0*h*t + (h-1)*t}
* 2: P_{N' + 1*h*t} P_{N' + 1*h*t + 1*t} ... P_{N' + 1*h*t + (h-1)*t}
* 2: P_{N' + 2*h*t} P_{N' + 2*h*t + 1*t} ... P_{N' + 2*h*t + (h-1)*t}
* 4:
* .
* .
* .
* H-1: P_{N' + (H-2)*h*t} P_{N' + (H-2)*h*t + 1*t} ... P_{N' + (H-2)*h*t + (h-1)*t}
* H : P_{N' + (H-1)*h*t} P_{N' + (H-1)*h*t + 1*t} ... P_{N' + (H-1)*h*t + (h-1)*t}
* The function throws an exception if the specified index is invalid.
*
* Total generators = (N + H * h) * t = 2304
* @param index The generator index, consisting of an index and sub-index.
* @return A reference to the generator data for the specified generator index.
* @throws An exception if the specified index is invalid.
*
* @note TODO: Write a generator indexing example
*/
generator_data const& get_generator_data(generator_index_t index)
{
// Initialize the global array of generator data
auto& global_generator_data = init_generator_data();

// Handle default generators
if (index.index == 0) {
ASSERT(index.sub_index < GEN_PARAMS.num_default_generators);
ASSERT(index.sub_index < num_default_generators);
return *global_generator_data[index.sub_index];
}
ASSERT(index.index <= GEN_PARAMS.num_hash_indices);
ASSERT(index.sub_index < GEN_PARAMS.num_generators_per_hash_index);
return *global_generator_data[GEN_PARAMS.hash_indices_generator_offset +
((index.index - 1) * GEN_PARAMS.num_generators_per_hash_index) + index.sub_index];

// Handle user-defined generators
ASSERT(index.index <= num_hash_indices);
size_t global_index_offset = 0;
if (0 < index.index && index.index <= LOW.num_indices) {
// Calculate the global index of the generator for the LOW hash index
ASSERT(index.sub_index < LOW.num_generators_per_index);
const size_t local_index_offset = 0;
const size_t generator_count_offset = 0;
global_index_offset =
generator_count_offset + (index.index - local_index_offset - 1) * LOW.num_generators_per_index;

} else if (index.index <= (LOW.num_indices + MID.num_indices)) {
// Calculate the global index of the generator for the MID hash index
ASSERT(index.sub_index < MID.num_generators_per_index);
const size_t local_index_offset = LOW.num_indices;
const size_t generator_count_offset = LOW.total_generators();
global_index_offset =
generator_count_offset + (index.index - local_index_offset - 1) * MID.num_generators_per_index;

} else if (index.index <= (LOW.num_indices + MID.num_indices + HIGH.num_indices)) {
// Calculate the global index of the generator for the HIGH hash index
const size_t local_index_offset = LOW.num_indices + MID.num_indices;
const size_t generator_count_offset = LOW.total_generators() + MID.total_generators();
ASSERT(index.sub_index < HIGH.num_generators_per_index);
global_index_offset =
generator_count_offset + (index.index - local_index_offset - 1) * HIGH.num_generators_per_index;

} else {
// Throw an exception for invalid index values
throw_or_abort(format("invalid hash index: ", index.index));
}

// Return a reference to the user-defined generator with the specified index and sub-index
return *global_generator_data[num_default_generators + global_index_offset + index.sub_index];
}

const fixed_base_ladder* generator_data::get_ladder(size_t num_bits) const
Expand All @@ -290,7 +322,7 @@ const fixed_base_ladder* generator_data::get_ladder(size_t num_bits) const
const fixed_base_ladder* generator_data::get_hash_ladder(size_t num_bits) const
{
init_generator_data();
return get_ladder_internal(hash_ladder, num_bits);
return get_ladder_internal(ladder, num_bits, aux_length);
}

} // namespace generators
Expand Down
5 changes: 2 additions & 3 deletions cpp/src/barretenberg/crypto/generators/generator_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,14 @@ struct fixed_base_ladder {
*/
constexpr size_t bit_length = 256;
constexpr size_t quad_length = bit_length / 2 + 1;
typedef std::array<fixed_base_ladder, quad_length> ladder_t;
constexpr size_t aux_length = 2;
typedef std::array<fixed_base_ladder, quad_length + aux_length> ladder_t;

struct generator_data {
grumpkin::g1::affine_element generator;
grumpkin::g1::affine_element aux_generator;
grumpkin::g1::affine_element skew_generator;
ladder_t ladder;
ladder_t aux_ladder;
ladder_t hash_ladder;

const fixed_base_ladder* get_ladder(size_t num_bits) const;
const fixed_base_ladder* get_hash_ladder(size_t num_bits) const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ TEST_F(join_split_tests, test_0_input_notes_and_detect_circuit_change)

constexpr uint32_t CIRCUIT_GATE_COUNT = 184517;
constexpr uint32_t GATES_NEXT_POWER_OF_TWO = 524288;
const uint256_t VK_HASH("24999463fd4168e633aad6171f8538e2e344e9136c3284f95bf607850a7f79bd");
const uint256_t VK_HASH("787c464414a2c2e3332314ff528bd236b13133c269c5704505a0f3a3ad56ad57");

auto number_of_gates_js = result.number_of_gates;
std::cout << get_verification_key()->sha256_hash() << std::endl;
Expand Down
2 changes: 1 addition & 1 deletion ts/src/barretenberg_api/pedersen.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ describe('pedersen', () => {

it('pedersenCompressWithHashIndex', () => {
const result = api.pedersenCompressWithHashIndex([new Fr(4n), new Fr(8n)], 7);
expect(result).toEqual(new Fr(12675961871866002745031098923411501942277744385859978302365013982702509949754n));
expect(result).toEqual(new Fr(11068631634751286805527305272746775861010877976108429785597565355072506728435n));
});

it('pedersenCommit', () => {
Expand Down

0 comments on commit d60b16a

Please sign in to comment.