Skip to content

Commit a3f15e4

Browse files
authored
Booth encoding (#106)
* booth encoding baseline * working msm with booth encoding * tidy * apply suggestions & remove leftovers
1 parent 2e7f8eb commit a3f15e4

File tree

1 file changed

+272
-26
lines changed

1 file changed

+272
-26
lines changed

src/msm.rs

+272-26
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,55 @@
1+
use std::ops::Neg;
2+
13
use ff::PrimeField;
24
use group::Group;
35
use pasta_curves::arithmetic::CurveAffine;
46

57
use crate::multicore;
68

9+
/// Returns the signed Booth digit of the scalar `el` for the window at
/// `window_index`.
///
/// * `window_index` - zero-based index of the window, counted from the least
///   significant end of the scalar.
/// * `window_size` - number of scalar bits consumed per window. Values from 1
///   to 31 are supported: the digit must fit in the `i32` return type.
/// * `el` - little-endian bytes of the scalar. Bytes past the end of the slice
///   are treated as zero.
///
/// The returned digit lies in `-2^(window_size - 1) ..= 2^(window_size - 1)`.
fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
    // Booth encoding:
    // * step by `window` size
    // * slice by size of `window + 1`
    // * each window overlaps the previous one by 1 bit
    // * append a zero bit to the least significant end
    // Indexing rule for example window size 3 where we slice by 4 bits:
    // `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3, -2, -2, -1, -1, 0]`
    // So we can reduce the bucket size without preprocessing scalars
    // and remembering them as in classic signed digit encoding

    // Start one bit below the window so the overlap bit is included; the
    // least significant window has no lower bit, hence the saturation.
    let skip_bits = (window_index * window_size).saturating_sub(1);
    let skip_bytes = skip_bits / 8;

    // Fill into a u64. A slice of `window_size + 1` bits that starts up to 7
    // bits into a byte can span more than four bytes once `window_size`
    // exceeds 24, so a 32-bit word would silently drop its top bit.
    let mut v = [0u8; 8];
    for (dst, src) in v.iter_mut().zip(el.iter().skip(skip_bytes)) {
        *dst = *src;
    }
    let mut tmp = u64::from_le_bytes(v);

    // pad with one 0 if slicing the least significant window
    if window_index == 0 {
        tmp <<= 1;
    }

    // remove the bits below the slice that share its first byte
    tmp >>= skip_bits - (skip_bytes * 8);
    // apply the booth window
    tmp &= (1 << (window_size + 1)) - 1;

    // the top bit of the slice selects the negative half of the table
    let sign = tmp & (1 << window_size) == 0;

    // div ceil by 2
    tmp = (tmp + 1) >> 1;

    // find the booth action index
    if sign {
        tmp as i32
    } else {
        ((!(tmp - 1) & ((1 << window_size) - 1)) as i32).neg()
    }
}
52+
753
pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
854
let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
955

@@ -15,29 +61,9 @@ pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &
1561
(f64::from(bases.len() as u32)).ln().ceil() as usize
1662
};
1763

18-
fn get_at<F: PrimeField>(segment: usize, c: usize, bytes: &F::Repr) -> usize {
19-
let skip_bits = segment * c;
20-
let skip_bytes = skip_bits / 8;
21-
22-
if skip_bytes >= 32 {
23-
return 0;
24-
}
25-
26-
let mut v = [0; 8];
27-
for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
28-
*v = *o;
29-
}
64+
let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1;
3065

31-
let mut tmp = u64::from_le_bytes(v);
32-
tmp >>= skip_bits - (skip_bytes * 8);
33-
tmp %= 1 << c;
34-
35-
tmp as usize
36-
}
37-
38-
let segments = (256 / c) + 1;
39-
40-
for current_segment in (0..segments).rev() {
66+
for current_window in (0..number_of_windows).rev() {
4167
for _ in 0..c {
4268
*acc = acc.double();
4369
}
@@ -73,12 +99,15 @@ pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &
7399
}
74100
}
75101

76-
let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; (1 << c) - 1];
102+
let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; 1 << (c - 1)];
77103

78104
for (coeff, base) in coeffs.iter().zip(bases.iter()) {
79-
let coeff = get_at::<C::Scalar>(current_segment, c, coeff);
80-
if coeff != 0 {
81-
buckets[coeff - 1].add_assign(base);
105+
let coeff = get_booth_index(current_window as usize, c, coeff.as_ref());
106+
if coeff.is_positive() {
107+
buckets[coeff as usize - 1].add_assign(base);
108+
}
109+
if coeff.is_negative() {
110+
buckets[coeff.unsigned_abs() as usize - 1].add_assign(&base.neg());
82111
}
83112
}
84113

@@ -151,3 +180,220 @@ pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu
151180
acc
152181
}
153182
}
183+
184+
#[cfg(test)]
185+
mod test {
186+
187+
use std::ops::Neg;
188+
189+
use crate::{
190+
bn256::{Fr, G1Affine, G1},
191+
multicore,
192+
};
193+
use ark_std::{end_timer, start_timer};
194+
use ff::{Field, PrimeField};
195+
use group::{Curve, Group};
196+
use pasta_curves::arithmetic::CurveAffine;
197+
use rand_core::OsRng;
198+
199+
// Keeping the older implementation here for baseline comparison, debugging & benchmarking
200+
fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
201+
assert_eq!(coeffs.len(), bases.len());
202+
203+
let num_threads = multicore::current_num_threads();
204+
if coeffs.len() > num_threads {
205+
let chunk = coeffs.len() / num_threads;
206+
let num_chunks = coeffs.chunks(chunk).len();
207+
let mut results = vec![C::Curve::identity(); num_chunks];
208+
multicore::scope(|scope| {
209+
let chunk = coeffs.len() / num_threads;
210+
211+
for ((coeffs, bases), acc) in coeffs
212+
.chunks(chunk)
213+
.zip(bases.chunks(chunk))
214+
.zip(results.iter_mut())
215+
{
216+
scope.spawn(move |_| {
217+
multiexp_serial(coeffs, bases, acc);
218+
});
219+
}
220+
});
221+
results.iter().fold(C::Curve::identity(), |a, b| a + b)
222+
} else {
223+
let mut acc = C::Curve::identity();
224+
multiexp_serial(coeffs, bases, &mut acc);
225+
acc
226+
}
227+
}
228+
229+
// Keeping the older implementation here for baseline comparison, debugging & benchmarking
230+
fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
231+
let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
232+
233+
let c = if bases.len() < 4 {
234+
1
235+
} else if bases.len() < 32 {
236+
3
237+
} else {
238+
(f64::from(bases.len() as u32)).ln().ceil() as usize
239+
};
240+
241+
fn get_at<F: PrimeField>(segment: usize, c: usize, bytes: &F::Repr) -> usize {
242+
let skip_bits = segment * c;
243+
let skip_bytes = skip_bits / 8;
244+
245+
if skip_bytes >= 32 {
246+
return 0;
247+
}
248+
249+
let mut v = [0; 8];
250+
for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
251+
*v = *o;
252+
}
253+
254+
let mut tmp = u64::from_le_bytes(v);
255+
tmp >>= skip_bits - (skip_bytes * 8);
256+
tmp %= 1 << c;
257+
258+
tmp as usize
259+
}
260+
261+
let segments = (256 / c) + 1;
262+
263+
for current_segment in (0..segments).rev() {
264+
for _ in 0..c {
265+
*acc = acc.double();
266+
}
267+
268+
#[derive(Clone, Copy)]
269+
enum Bucket<C: CurveAffine> {
270+
None,
271+
Affine(C),
272+
Projective(C::Curve),
273+
}
274+
275+
impl<C: CurveAffine> Bucket<C> {
276+
fn add_assign(&mut self, other: &C) {
277+
*self = match *self {
278+
Bucket::None => Bucket::Affine(*other),
279+
Bucket::Affine(a) => Bucket::Projective(a + *other),
280+
Bucket::Projective(mut a) => {
281+
a += *other;
282+
Bucket::Projective(a)
283+
}
284+
}
285+
}
286+
287+
fn add(self, mut other: C::Curve) -> C::Curve {
288+
match self {
289+
Bucket::None => other,
290+
Bucket::Affine(a) => {
291+
other += a;
292+
other
293+
}
294+
Bucket::Projective(a) => other + a,
295+
}
296+
}
297+
}
298+
299+
let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; (1 << c) - 1];
300+
301+
for (coeff, base) in coeffs.iter().zip(bases.iter()) {
302+
let coeff = get_at::<C::Scalar>(current_segment, c, coeff);
303+
if coeff != 0 {
304+
buckets[coeff - 1].add_assign(base);
305+
}
306+
}
307+
308+
// Summation by parts
309+
// e.g. 3a + 2b + 1c = a +
310+
// (a) + b +
311+
// ((a) + b) + c
312+
let mut running_sum = C::Curve::identity();
313+
for exp in buckets.into_iter().rev() {
314+
running_sum = exp.add(running_sum);
315+
*acc += &running_sum;
316+
}
317+
}
318+
}
319+
320+
/// Checks `get_booth_index` by using its digits in a windowed
/// double-and-add scalar multiplication and comparing against the curve's
/// own `point * scalar`, for window sizes 1 through 9.
#[test]
fn test_booth_encoding() {
    /// Scalar multiplication driven by Booth digits of width `window`.
    fn mul(scalar: &Fr, point: &G1Affine, window: usize) -> G1Affine {
        let repr = scalar.to_repr();
        let windows = Fr::NUM_BITS as usize / window + 1;

        // table[i] = i * point, for every digit magnitude that can occur
        let table: Vec<_> = (0..=1 << (window - 1))
            .map(|i| point * Fr::from(i as u64))
            .collect();

        let mut acc = G1::identity();
        for w in (0..windows).rev() {
            for _ in 0..window {
                acc = acc.double();
            }

            let digit = super::get_booth_index(w, window, repr.as_ref());
            let magnitude = digit.unsigned_abs() as usize;

            if digit < 0 {
                acc += table[magnitude].neg();
            } else if digit > 0 {
                acc += table[magnitude];
            }
        }

        acc.to_affine()
    }

    let (scalars, points): (Vec<_>, Vec<_>) = (0..10)
        .map(|_| {
            let scalar = Fr::random(OsRng);
            let point = G1Affine::random(OsRng);
            (scalar, point)
        })
        .unzip();

    for window in 1..10 {
        for (scalar, point) in scalars.iter().zip(points.iter()) {
            let got = mul(scalar, point, window);
            let want = (point * scalar).to_affine();
            assert_eq!(got, want);
        }
    }
}
365+
366+
fn run_msm_cross<C: CurveAffine>(min_k: usize, max_k: usize) {
367+
let points = (0..1 << max_k)
368+
.map(|_| C::Curve::random(OsRng))
369+
.collect::<Vec<_>>();
370+
let mut affine_points = vec![C::identity(); 1 << max_k];
371+
C::Curve::batch_normalize(&points[..], &mut affine_points[..]);
372+
let points = affine_points;
373+
374+
let scalars = (0..1 << max_k)
375+
.map(|_| C::Scalar::random(OsRng))
376+
.collect::<Vec<_>>();
377+
378+
for k in min_k..=max_k {
379+
let points = &points[..1 << k];
380+
let scalars = &scalars[..1 << k];
381+
382+
let t0 = start_timer!(|| format!("w/ booth k={}", k));
383+
let e0 = super::best_multiexp(scalars, points);
384+
end_timer!(t0);
385+
386+
let t1 = start_timer!(|| format!("w/o booth k={}", k));
387+
let e1 = best_multiexp(scalars, points);
388+
end_timer!(t1);
389+
390+
assert_eq!(e0, e1);
391+
}
392+
}
393+
394+
/// Cross-checks the Booth-encoded MSM against the baseline on `G1Affine`
/// for input sizes `2^10` through `2^18`.
#[test]
fn test_msm_cross() {
    let (min_k, max_k) = (10, 18);
    run_msm_cross::<G1Affine>(min_k, max_k);
    // Larger range, left disabled:
    // run_msm_cross::<G1Affine>(19, 23);
}
399+
}

0 commit comments

Comments
 (0)