diff --git a/Cargo.lock b/Cargo.lock index e82d47d690a..6e367788c7c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -31,6 +31,7 @@ dependencies = [ "ark-bls12-381", "ark-bn254 0.5.0", "ark-ff 0.5.0", + "ark-std 0.5.0", "cfg-if", "hex", "num-bigint", diff --git a/acvm-repo/acir_field/Cargo.toml b/acvm-repo/acir_field/Cargo.toml index f9ae3a4ea3f..fea0775e10c 100644 --- a/acvm-repo/acir_field/Cargo.toml +++ b/acvm-repo/acir_field/Cargo.toml @@ -23,6 +23,7 @@ serde.workspace = true ark-bn254.workspace = true ark-bls12-381 = { workspace = true, optional = true } ark-ff.workspace = true +ark-std.workspace = true cfg-if.workspace = true diff --git a/acvm-repo/acir_field/src/field_element.rs b/acvm-repo/acir_field/src/field_element.rs index 0249b410aa7..e53fb760476 100644 --- a/acvm-repo/acir_field/src/field_element.rs +++ b/acvm-repo/acir_field/src/field_element.rs @@ -1,5 +1,6 @@ use ark_ff::PrimeField; use ark_ff::Zero; +use ark_std::io::Write; use num_bigint::BigUint; use serde::{Deserialize, Serialize}; use std::borrow::Cow; @@ -195,26 +196,9 @@ impl AcirField for FieldElement { /// This is the number of bits required to represent this specific field element fn num_bits(&self) -> u32 { - let bytes = self.to_be_bytes(); - - // Iterate through the byte decomposition and pop off all leading zeroes - let mut iter = bytes.iter().skip_while(|x| (**x) == 0); - - // The first non-zero byte in the decomposition may have some leading zero-bits. - let Some(head_byte) = iter.next() else { - // If we don't have a non-zero byte then the field element is zero, - // which we consider to require a single bit to represent. - return 1; - }; - let num_bits_for_head_byte = head_byte.ilog2(); - - // Each remaining byte in the byte decomposition requires 8 bits. - // - // Note: count will panic if it goes over usize::MAX. - // This may not be suitable for devices whose usize < u16 - let tail_length = iter.count() as u32; - - 8 * tail_length + num_bits_for_head_byte + 1 + let mut bit_counter = BitCounter::default(); + self.0.serialize_uncompressed(&mut bit_counter).unwrap(); + bit_counter.bits() } fn to_u128(self) -> u128 { @@ -354,6 +338,52 @@ impl SubAssign for FieldElement { } } +#[derive(Default, Debug)] +struct BitCounter { + /// Total number of non-zero bytes we found. + count: usize, + /// Total bytes we found. + total: usize, + /// The last non-zero byte we found. + head_byte: u8, +} + +impl BitCounter { + fn bits(&self) -> u32 { + // If we don't have a non-zero byte then the field element is zero, + // which we consider to require a single bit to represent. + if self.count == 0 { + return 1; + } + + let num_bits_for_head_byte = self.head_byte.ilog2(); + + // Each remaining byte in the byte decomposition requires 8 bits. + // + // Note: count will panic if it goes over usize::MAX. + // This may not be suitable for devices whose usize < u16 + let tail_length = (self.count - 1) as u32; + 8 * tail_length + num_bits_for_head_byte + 1 + } +} + +impl Write for BitCounter { + fn write(&mut self, buf: &[u8]) -> ark_std::io::Result { + for byte in buf { + self.total += 1; + if *byte != 0 { + self.count = self.total; + self.head_byte = *byte; + } + } + Ok(buf.len()) + } + + fn flush(&mut self) -> ark_std::io::Result<()> { + Ok(()) + } +} + #[cfg(test)] mod tests { use super::{AcirField, FieldElement};